attention.py 313 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
from importlib.metadata import PackageNotFoundError
10
import math
11
import os
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
import warnings
14
import logging
15

cyanguwa's avatar
cyanguwa committed
16
import numpy as np
17
from packaging.version import Version as PkgVersion
18
19
20

import torch

21
import transformer_engine_torch as tex
22
23
24
25
26
from transformer_engine.pytorch.utils import (
    get_cudnn_version,
    nvtx_range_pop,
    nvtx_range_push,
)
27
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
28
29
    fused_attn_fwd,
    fused_attn_bwd,
30
    FusedAttnBackend,
31
32
33
34
35
36
37
    META_QKV,
    META_O,
)
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    get_fp8_te_dtype,
    get_fp8_torch_dtype,
38
)
39
from transformer_engine.pytorch.float8_tensor import Float8Tensor
40
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
41
from transformer_engine.pytorch.module import LayerNormLinear, Linear
42
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
43
44
45
46
47
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
48
    get_default_init_method,
49
50
51
52
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
53
    AttnBiasTypes,
54
    QKVLayouts,
55
    dist_group_type,
56
    TE_DType,
57
58
59
60
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
61
    get_distributed_rank,
62
    checkpoint,
63
64
65
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
66
67
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
68
)
69
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
70
from transformer_engine.pytorch.graph import is_graph_capturing
71
72
73
74
75
from transformer_engine.pytorch.tensor.quantized_tensor import (
    QuantizedTensor,
    prepare_for_saving,
    restore_from_saved,
)
76

77
78
79
80
81
82
# Import attention utils
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
83
84


85
86
87
88
# Setup Attention Logging
attn_log.setup_logging()

# Global vars for flash attn imports. They stay None unless a flash-attn v2
# package within the supported version window is found below, in which case
# the conditional imports rebind them.
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None
try:
    fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
    # Only report the missing package when CUDA is available, the device has
    # compute capability >= 8.0, and flash attention is enabled.
    if (
        torch.cuda.is_available()
        and get_device_compute_capability() >= (8, 0)
        and dpa_utils._NVTE_FLASH_ATTN
    ):
        attn_log.fa_logger.debug(
            "flash-attn v2 is not installed. To use, please install it by"
            """ "pip3 install flash-attn".""",
        )
else:
    # The accepted minimum version depends on the device: compute capability
    # >= 10.0 uses `version_required_blackwell`, everything else uses
    # `version_required`. The upper bound is `max_version` in both cases.
    if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
        if fa_utils.version_required_blackwell <= fa_utils.version <= fa_utils.max_version:
            fa_utils.is_installed = True
    elif fa_utils.version_required <= fa_utils.version <= fa_utils.max_version:
        fa_utils.is_installed = True

    if fa_utils.is_installed:
        # Rebind the module-level placeholders declared above.
        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
        from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_forward as _flash_attn_varlen_fwd,
        )
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_backward as _flash_attn_varlen_bwd,
        )

        # Setup Flash attention utils
        fa_utils.set_flash_attention_version()
    elif (
        torch.cuda.is_available()
        and get_device_compute_capability() >= (8, 0)
        and dpa_utils._NVTE_FLASH_ATTN
    ):
        # Installed but outside the supported window: report the window that
        # applies to this device together with the version that was found.
        attn_log.fa_logger.warning(
            "Supported flash-attn versions are %s. Found flash-attn %s.",
            dpa_utils._get_supported_versions(
                (
                    fa_utils.version_required
                    if get_device_compute_capability() < (10, 0)
                    else fa_utils.version_required_blackwell
                ),
                fa_utils.max_version,
            ),
            fa_utils.version,
        )

# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
try:
    fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
    # Only report the missing package when CUDA is available, the device has
    # compute capability >= 9.0, and flash attention is enabled.
    if (
        torch.cuda.is_available()
        and get_device_compute_capability() >= (9, 0)
        and dpa_utils._NVTE_FLASH_ATTN
    ):
        attn_log.fa_logger.debug(
            "flash-attn v3 is not installed. To use, please install it by \n%s",
            fa_utils.v3_installation_steps,
        )
else:
    # NOTE(review): unlike the v2 names, these *_v3 names get no None
    # placeholder, so they are unbound when flashattn-hopper is missing. Users
    # presumably gate on `fa_utils.use_v3` before touching them — confirm.
    from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flashattn_hopper.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
    from flashattn_hopper.flash_attn_interface import (
        _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
    )
    from flashattn_hopper.flash_attn_interface import (
        _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
    )

    fa_utils.set_flash_attention_3_params()
177

178
# Global vars for available attention backends and ALiBi cache.
# Module-level mutable state. Judging by the key names, the first dict records the
# most recent backend selection and the second caches ALiBi slopes/bias together
# with the parameters they were built for. The code that reads and writes them is
# outside this view — confirm before relying on specific semantics.
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
    "backend_selection_requires_update": False,
}

_alibi_cache = {
    "_num_heads": None,
    "_alibi_slopes": None,
    "_max_seqlen_q": None,
    "_max_seqlen_kv": None,
    "_bottom_right_alignment": True,
    "_alibi_bias": None,
    "_alibi_slopes_require_update": False,
    "_alibi_bias_require_update": False,
}

__all__ = ["DotProductAttention", "MultiheadAttention"]
200
201


202
203
204
205
206
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` itself when its last-dim stride is 1, else a contiguous copy.

    Only the innermost stride is checked, so a tensor that is non-contiguous in its
    outer dimensions but unit-strided in the last one is returned unchanged.
    """
    if tensor.stride(-1) == 1:
        return tensor
    return tensor.contiguous()


207
208
209
def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Sends ``send_tensor`` to ``send_dst`` and receives into ``recv_tensor`` from
    ``recv_src`` within ``cp_group``.

    Even ``rank`` values issue the send before the receive and odd ones issue the
    receive first. When ``batch_p2p_comm`` is truthy, both operations are submitted
    together through ``torch.distributed.batch_isend_irecv``. Otherwise they are
    posted one by one with ``isend`` / ``irecv``.

    Returns:
        List of request handles, in the same send/recv order used to issue them.
        The caller waits on each one.
    """
    send_first = rank % 2 == 0

    if not batch_p2p_comm:
        # Individually posted ops start immediately, so issue order matters here.
        if send_first:
            send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
            recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            return [send_req, recv_req]
        recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
        send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
        return [recv_req, send_req]

    # P2POp only describes an operation; ordering is decided by the list below.
    send_op = torch.distributed.P2POp(torch.distributed.isend, send_tensor, send_dst, cp_group)
    recv_op = torch.distributed.P2POp(torch.distributed.irecv, recv_tensor, recv_src, cp_group)
    ordered_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
    return torch.distributed.batch_isend_irecv(ordered_ops)


249
250
251
252
253
254
255
256
257
258
259
260
261
262
@jit_fuser
def flash_attn_fwd_out_correction_init(
    out_init_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_init_step: torch.Tensor,
    seq_dim: int,
):
    """Merge partial outputs of the first step in Attention with context parallelism

    Scales the first step's partial output by ``exp(softmax_lse_init_step - softmax_lse)``.
    This is the share of that step in the merged softmax normaliser ``softmax_lse``.
    The scale tensor's dim 2 is moved to ``seq_dim``, and a trailing singleton axis is
    added so it broadcasts against the output.

    Returns:
        The rescaled output, cast back to ``out_init_step.dtype``. The cast undoes any
        dtype promotion from multiplying by the LSE-derived scale, e.g. a float32
        LSE times a half-precision output.
    """
    step_weight = torch.exp(softmax_lse_init_step - softmax_lse)
    step_weight = step_weight.movedim(2, seq_dim).unsqueeze(-1)
    rescaled = out_init_step * step_weight
    return rescaled.to(out_init_step.dtype)


263
@jit_fuser
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Merge partial outputs of each step in Attention with context parallelism

    Accumulates ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` into ``out``
    in place. As in the init variant, the scale's dim 2 is moved to ``seq_dim`` and
    given a trailing singleton axis for broadcasting. Returns nothing.
    """
    step_weight = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
    out.add_(out_per_step * step_weight.unsqueeze(-1))


278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
@jit_fuser
def flash_attn_fwd_second_half_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Merge second half of partial outputs of each step in Attention with context parallelism

    Only index 1 of ``out`` along ``seq_dim`` is updated, in place. The matching part of
    ``softmax_lse`` is taken by splitting its last dim into ``(2, -1)`` and keeping index 1.
    """
    out_second = out.select(seq_dim, 1)
    lse_second = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :]
    step_weight = torch.exp(softmax_lse_per_step - lse_second)
    step_weight = step_weight.movedim(2, seq_dim).unsqueeze(-1)
    out_second.add_(out_per_step * step_weight)


295
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge softmax stats of each step in Attention with context parallelism

    Overwrites ``softmax_lse`` in place with
    ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``. This is computed as
    ``max + log1p(exp(min - max))`` so the exponent is never positive.
    """
    larger = torch.max(softmax_lse, softmax_lse_per_step)
    smaller = torch.min(softmax_lse, softmax_lse_per_step)
    softmax_lse.copy_(larger + torch.log1p(torch.exp(smaller - larger)))
305
306


307
308
309
310
311
312
313
314
315
316
317
318
319
@jit_fuser
def flash_attn_fwd_second_half_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge second half of softmax stats of each step in Attention with context parallelism

    Applies the same stable log-add-exp merge as the full-sequence variant, but only
    to ``softmax_lse[..., 1, :]``, updated in place.

    Unlike the second-half output correction, this function does not split the last
    dim itself. NOTE(review): it assumes the caller already passes ``softmax_lse``
    viewed as ``[..., 2, s//2]`` — confirm at call sites.
    """
    second_half = softmax_lse[..., 1, :]
    larger = torch.max(second_half, softmax_lse_per_step)
    smaller = torch.min(second_half, softmax_lse_per_step)
    second_half.copy_(larger + torch.log1p(torch.exp(smaller - larger)))


320
321
@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens: torch.Tensor,
    cu_seqlens_padded_on_cp_rank: torch.Tensor,
    cp_size: int,
    cp_rank: int,
    first_half: bool,
    second_half: bool,
):
    """Compute cu_seqlens of a context parallelism rank

    Each sequence's padded per-rank length is split into two equal chunks. Per sequence:

    * the first-half contribution is the number of real tokens that fall in chunk
      ``cp_rank``: ``clamp(seqlen - cp_rank * chunk, 0, chunk)``;
    * the second-half contribution uses chunk ``2 * cp_size - cp_rank - 1``.

    Only the requested halves are counted. The per-sequence counts are prefix-summed
    into a tensor shaped like ``cu_seqlens``, with a leading 0.
    """
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    chunk_len = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    lower_bound = torch.zeros_like(seqlens)
    result = torch.zeros_like(cu_seqlens)
    if first_half:
        tokens_first = (seqlens - cp_rank * chunk_len).clamp(lower_bound, chunk_len)
        result[1:].add_(tokens_first)
    if second_half:
        chunk_index = 2 * cp_size - cp_rank - 1
        tokens_second = (seqlens - chunk_index * chunk_len).clamp(lower_bound, chunk_len)
        result[1:].add_(tokens_second)
    result.cumsum_(dim=0)
    return result


346
@jit_fuser
def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to
    be contiguous before attention compute. This function is to compute sequence chunk ids for
    reordering.

    The result is an int32 tensor of length ``2 * cp_size`` on ``device``. Position ``r``
    holds ``2 * r``, and position ``r + cp_size`` holds ``2 * cp_size - 2 * r - 1``,
    for ``r`` in ``[0, cp_size)``.
    """
    ranks = torch.arange(cp_size, dtype=torch.int32, device=device)
    # Python-int scalars do not promote, so both halves stay int32.
    return torch.cat((2 * ranks, 2 * cp_size - 2 * ranks - 1))


361
@jit_fuser
def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    We need to reorder sequence chunks back to discontiguous after attention compute. This function
    is to compute sequence chunk ids for reordering.

    The result is an int32 tensor of length ``2 * cp_size`` on ``device``. Position ``2 * r``
    holds ``r``, and position ``2 * r + 1`` holds ``2 * cp_size - r - 1``, for ``r`` in
    ``[0, cp_size)``.
    """
    ranks = torch.arange(cp_size, dtype=torch.int32, device=device)
    # Stack the pairs as rows, then flatten row-major to interleave them.
    pairs = torch.stack((ranks, 2 * cp_size - ranks - 1), dim=1)
    return pairs.flatten()


@jit_fuser
def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
    """Reorder sequence chunk for A2A communication before attention compute.

    Moves dim 0 to ``seq_dim``. It then splits the adjacent pair of dims at ``seq_dim`` and
    ``seq_dim + 1`` into ``(2 * cp_size, -1)`` and gathers those chunks along
    ``seq_dim`` in the order given by ``chunk_ids_for_a2a``.
    """
    # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
    # or [cp, s, b, np//cp, hn] unchanged when seq_dim == 0
    moved = x.movedim(0, seq_dim).contiguous()
    # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    chunked = moved.view(
        *moved.shape[:seq_dim], cp_size * 2, -1, *moved.shape[(seq_dim + 2) :]
    )
    return chunked.index_select(seq_dim, chunk_ids_for_a2a)


@jit_fuser
def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
    """Reorder sequence chunk for A2A communication after attention compute.

    Moves the ``2 * cp_size`` chunk dim at ``seq_dim`` to the front and gathers the
    chunks in ``chunk_ids_for_a2a`` order. The leading dim is then split into
    ``(cp_size, 2)``.
    """
    # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
    # or [cp*2, s//2, b, np//cp, hn] unchanged when seq_dim == 0
    leading = x.movedim(seq_dim, 0).contiguous()
    reordered = leading.index_select(0, chunk_ids_for_a2a)
    # [cp*2, ...] -> [cp, 2, ...]
    return reordered.view(cp_size, 2, *reordered.shape[1:])


def flash_attn_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
    chunk_ids_for_a2a: torch.Tensor,
    seq_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """A2A communication for context parallelism.

    Exchanges each tensor in ``a2a_inputs`` with ``all_to_all_single`` over ``cp_group``.

    * ``before_attn=True``: the heads dim of each input is split into ``cp_size`` groups
      and moved to the front before the exchange. Afterwards, the received sequence
      chunks are reordered with ``chunk_ids_for_a2a`` and merged into a single sequence
      dim (``[b, s, np, hn] -> [b, cp*s, np//cp, hn]``, or the sbhd equivalent).
    * ``before_attn=False``: the inverse. Sequence chunks are reordered before the
      exchange, and the head groups are re-assembled afterwards into ``[b*s, np, hn]``
      (or ``[s*b, np, hn]``).

    The work is software-pipelined. In iteration ``i``, the exchange for input ``i - 1``
    is launched, input ``i`` is pre-processed, and the received output of input
    ``i - 2`` is post-processed on ``cp_stream`` once its request completes. The relative
    order of the last two steps differs between the two directions. Before returning,
    the current stream waits on ``cp_stream``.

    Args:
        a2a_inputs: a tensor or a list of tensors. A list argument is not modified;
            the function works on a private copy.
        chunk_ids_for_a2a: index tensor used to reorder the ``2 * cp_size`` sequence chunks.
        seq_dim: sequence dimension of the input layout.
        cp_size: number of ranks taking part in the all-to-all.
        cp_group: process group for the all-to-all.
        cp_stream: side stream on which received tensors are post-processed.
        before_attn: direction of the exchange (see above).

    Returns:
        A single tensor when exactly one input was given, otherwise a list of
        tensors in input order.
    """
    # The pipeline overwrites entries with their pre-processed form. Copy the list
    # so the caller's list is never mutated.
    a2a_inputs = list(a2a_inputs) if isinstance(a2a_inputs, list) else [a2a_inputs]
    num_inputs = len(a2a_inputs)
    a2a_outputs, a2a_reqs = [None] * num_inputs, [None] * num_inputs
    if before_attn:
        for i in range(num_inputs + 2):
            if 0 < i < num_inputs + 1:
                # launch the exchange for the input pre-processed in the previous iteration
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # reorder the sequence chunks
                    x = reorder_seq_chunks_for_a2a_before_attn(
                        x, chunk_ids_for_a2a, seq_dim, cp_size
                    )
                    # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                    # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
                    a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
            if i < num_inputs:
                x = a2a_inputs[i]
                # [b, s, np, hn] -> [b, s, cp, np//cp, hn]
                # or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
                x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
                # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
                # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
                a2a_inputs[i] = x.movedim(-3, 0).contiguous()
    else:
        for i in range(num_inputs + 2):
            if 0 < i < num_inputs + 1:
                # launch the exchange for the input pre-processed in the previous iteration
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            if i < num_inputs:
                x = a2a_inputs[i]
                # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
                # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
                x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
                # reorder the sequence chunks
                a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
                    x, chunk_ids_for_a2a, seq_dim, cp_size
                )
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
                    # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
                    x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
                    # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
                    # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
                    a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if num_inputs == 1 else a2a_outputs


471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
# Memo for _get_cu_seqlens_info_with_cp, keyed by (batch_size, max_seqlen, cp_size, device).
_cu_seqlens_info_with_cp_cache = {}


def _get_cu_seqlens_info_with_cp(
    batch_size: int,
    max_seqlen: int,
    cp_size: int,
    cu_seqlens: torch.Tensor,
):
    """Cumulative sequence lengths with CP being considered.

    Returns ``(cu_seqlens // cp_size, cu_seqlens // (cp_size * 2))``.

    Results are memoized per ``(batch_size, max_seqlen, cp_size, device)``. The device is
    part of the key so that tensors cached for one device are never handed back for
    another. The *values* of ``cu_seqlens`` are not part of the key: later calls with
    the same key reuse the first result. NOTE(review): this assumes ``cu_seqlens`` is
    fully determined by ``batch_size`` and ``max_seqlen`` (e.g. unpadded sequences) —
    confirm for padded inputs.
    """
    # The dict is only mutated, never rebound, so no `global` statement is needed.
    key = (batch_size, max_seqlen, cp_size, cu_seqlens.device)
    cached = _cu_seqlens_info_with_cp_cache.get(key)
    if cached is None:
        cached = (cu_seqlens // cp_size, cu_seqlens // (cp_size * 2))
        _cu_seqlens_info_with_cp_cache[key] = cached
    return cached


490
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
491
    """
492
493
494
    Attention implementation with context parallelism. Exchange KV between CP ranks
    with P2P in ring topology. Split attention compute into multiple steps, and overlap
    current-step compute with next-step communication.
495
496
497
498
499

    This implementation also supports hierarchical CP, which parallelizes attention
    heads in low-level CP groups and parallelizes sequence dimension in high-level CP
    groups. For more details, please refer to `LongVILA <https://arxiv.org/abs/2408.10188>`_
    and `USP <https://arxiv.org/abs/2405.07719>`_.
500
501
502
    """

    @staticmethod
503
504
505
506
507
508
509
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
510
        cu_seqlens_kv,
511
        max_seqlen_q,
512
        max_seqlen_kv,
513
514
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
515
516
517
518
519
520
521
522
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
523
524
        fp8,
        fp8_meta,
525
526
527
        cp_group,
        cp_global_ranks,
        cp_stream,
528
        quantizers,
529
        pad_between_seqs,
530
    ):
531
        # pylint: disable=missing-function-docstring
532
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
533
534
535
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
        if isinstance(cp_group, list):
            assert (
                qkv_format != "thd"
            ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
            assert attn_bias_type == "no_bias", (
                f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
                " yet!"
            )
            cp_group_a2a = cp_group[0]
            cp_size_a2a = get_distributed_world_size(cp_group_a2a)
            rank_a2a = get_distributed_rank(cp_group_a2a)
            cp_group = cp_group[1]
        else:
            cp_group_a2a = None
            cp_size_a2a = 1
            rank_a2a = 0

553
554
        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
555
556
        send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
557
558
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

559
560
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
561

562
        batch_dim = None
563
        seq_dim = None
564
        cu_seqlens_q_half, cu_seqlens_kv_half = None, None
565
        if qkv_format in ["bshd", "sbhd"]:
566
            seq_dim = qkv_format.index("s")
567
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
568
569
570
571
572
573
574
575
576
            cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None
            if use_fused_attention:
                batch_dim = qkv_format.index("b")
                cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp(
                    q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q
                )
                cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp(
                    q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv
                )
577
578
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
579
580
            cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
            cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
581
582
583
584
585

        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
586

587
        fused_attn_backend = None
588
        qkv_dtype = q.dtype
589
590
591
        amax_per_step = None
        S_quantizer_per_step = [None for _ in range(cp_size)]
        O_CP_quantizer_per_step = [None for _ in range(cp_size)]
592
593
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
594
595
596
597
598
599
600
601
602
603
604
        is_output_fp8 = False

        (
            QKV_quantizer,
            O_quantizer,
            O_CP_quantizer,
            S_quantizer,
            dQKV_quantizer,
            dQKV_CP_quantizer,
            dO_quantizer,
            dP_quantizer,
605
        ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True)
606

607
608
609
        if fp8:
            if use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
610

611
612
613
614
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
615
616
617
618
619
                is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
                if is_input_fp8:
                    QKV_quantizer = q._quantizer
                    q, k, v = q._data, k._data, v._data
                else:
620
621
                    q_f16, k_f16, v_f16 = q, k, v
                    if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
622
                        q = QKV_quantizer(q_f16)._data
623
                    if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
624
625
626
627
628
629
630
631
                        k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]]
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
                # partial result quantizer
                for i in range(cp_size):
                    S_quantizer_per_step[i] = S_quantizer.copy()
                    S_quantizer_per_step[i].amax = amax_per_step[0][i]
                    O_CP_quantizer_per_step[i] = O_CP_quantizer.copy()
                    O_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
632
633
634
635
636
637
638
639
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            q_f16 = q
            if use_fused_attention:
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if cp_size_a2a > 1:
640
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device)
641

642
643
644
645
646
            q, k, v = flash_attn_a2a_communicate(
                [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
            )
            if not fp8:
                q_f16 = q
647
            elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
648
                q_f16 = q
649
                q = QKV_quantizer(q_f16)._data
650

651
652
653
        assert qkv_format == "thd" or (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"
654
        if causal:
655
656
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
657
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
658
659
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
660
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
661
        if attn_bias is not None:
662
            assert len(attn_bias.shape) == 4, (
663
664
665
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
666
667
668
            assert (
                attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
            ), "Sequence length does not meet divisible requirements!"
669
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
670
671
672
673
674
675
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
676
677
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
678
679
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
680
            )
681
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
682

683
684
685
686
687
        softmax_lse_in_packed_format = False
        if qkv_format == "thd":
            if use_fused_attention:
                softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
            else:
688
                softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or fa_utils.use_v3
689

690
        flash_attn_fwd = None
691
692
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
693
            if fa_utils.use_v3:
694
695
696
697
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
698
699
                fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
            else:
700
701
702
703
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
704
705
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
706
                if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus) or fa_utils.use_v3:
707
                    fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
708
                elif fa_utils.v2_7_0_plus:
709
710
                    fa_forward_kwargs["window_size_left"] = -1
                    fa_forward_kwargs["window_size_right"] = 0 if causal else -1
711
                if fa_utils.v2_4_plus:
712
                    fa_forward_kwargs["alibi_slopes"] = None
713
                if fa_utils.v2_5_7_plus and qkv_format == "thd":
714
                    fa_forward_kwargs["block_table"] = None
715
                if fa_utils.v2_6_0_plus:
716
                    fa_forward_kwargs["softcap"] = 0.0
717

718
719
720
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
721
        attn_bias_inputs = [None, None]
722
723
724
725
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
726
        attn_biases = [None for _ in range(cp_size)]
727
728
729
730
731
732
733

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
734
        if qkv_format in ["bshd", "sbhd"]:
735
736
737
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
738
739
        send_recv_reqs = [[], []]

740
        out = None
741
        for i in range(cp_size + 1):
742
            if i < cp_size:
743
                with torch.cuda.stream(flash_attn_streams[i % 2]):
744
                    # wait until KV is received
745
                    for req in send_recv_reqs[(i + 1) % 2]:
746
747
                        req.wait()

748
749
750
751
752
753
754
755
756
757
758
759
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

760
                    if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
761
762
763
                        kv_inputs[i % 2] = p2p_comm_buffers[i]
                    else:
                        # KV exchange is in BF16/FP16, cast received KV in each step
764
                        kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
765
766
                    if causal:
                        if i == 0:
767
                            if pad_between_seqs:
768
769
770
771
772
773
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
774
775
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
776
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
777
778
779
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
796
                            if use_fused_attention:
797
798
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
799
800
801
802
803
804
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
805
                                    ).contiguous()
806
807
808
809
810
811
812
813
814
815
816
817

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
818
                                fp8_meta_kwargs = {}
819
820
821
822
823
824
825
826
827
828
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
829
830
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
831

832
833
834
835
836
837
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
838
839
840
841
842
                                    q_part,
                                    k_part,
                                    v_part,
                                    fake_dtype=qkv_dtype,
                                    fused_attention_backend=fused_attn_backend,
843
844
845
846
847
848
849
850
851
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
852
                                )
853
854
855
856
857
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
858
                            else:
859
860
861
862
863
864
865
866
                                fa_forward_args_thd = []
                                if qkv_format == "thd":
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q,
                                        max_seqlen_kv,
                                    ]
867
                                fa_outputs = flash_attn_fwd(
868
                                    q_inputs[i % 2],
869
870
871
872
873
874
875
876
877
878
879
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
880
                                    causal=True,
881
                                    **fa_forward_kwargs,
882
                                )
883
                                if not fa_utils.v2_7_0_plus:
884
885
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
886
                                    if not fa_utils.use_v3:
887
888
889
890
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
891
                                    if not fa_utils.use_v3:
892
                                        rng_states[i] = fa_outputs[3]
893
                        elif i <= rank:
894
                            if pad_between_seqs:
895
896
897
898
899
900
901
902
903
904
905
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
906
907
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
908
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
909
910
911
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][0]
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
                                # [2, t, np, hn] -> [2, t/2, np, hn]
                                kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                    kv_inputs[i % 2], cu_seqlens_kv_padded, 0
                                )
928
                            if use_fused_attention:
929
                                kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
930
931
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
932
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
933
934
935
936
937
938
939
940
941
942
943
944

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
945
                                fp8_meta_kwargs = {}
946
947
948
949
950
951
952
953
954
955
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
956
957
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
958
959
960
961
962
963
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv // 2,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
964
965
966
967
                                    q_part,
                                    k_part,
                                    v_part,
                                    qkv_dtype,
968
969
970
971
972
973
974
975
976
977
978
979
980
981
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=(
                                        None
                                        if cu_seqlens_kv_padded is None
                                        else cu_seqlens_kv_padded // 2
                                    ),
                                    **fp8_meta_kwargs,
982
                                )
983
984
985
986
987
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
988
                            else:
989
                                fa_forward_args_thd = []
990
                                if qkv_format == "thd":
991
992
993
994
995
996
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q,
                                        max_seqlen_kv // 2,
                                    ]
997
998
                                if fa_utils.use_v3 or (
                                    fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
999
                                ):
1000
                                    fa_forward_kwargs["window_size"] = (-1, -1)
1001
                                elif fa_utils.v2_7_0_plus:
1002
1003
                                    fa_forward_kwargs["window_size_left"] = -1
                                    fa_forward_kwargs["window_size_right"] = -1
1004
                                fa_outputs = flash_attn_fwd(
1005
                                    q_inputs[i % 2],
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
1017
                                    causal=False,
1018
                                    **fa_forward_kwargs,
1019
                                )
1020
                                if not fa_utils.v2_7_0_plus:
1021
1022
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
1023
                                    if not fa_utils.use_v3:
1024
1025
1026
1027
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
1028
                                    if not fa_utils.use_v3:
1029
                                        rng_states[i] = fa_outputs[3]
1030
                        else:
1031
                            if pad_between_seqs:
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
1043
1044
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
1045
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1046
1047
1048
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q_half
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                q_inputs[i % 2] = q[:, 1, ...]
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                                q_inputs[i % 2] = q[1]
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                # [t, np, hn] -> [t/2, np, hn]
                                q_inputs[i % 2] = tex.thd_read_half_tensor(
                                    q, cu_seqlens_q_padded, 1
                                )
1068
                            if use_fused_attention:
1069
                                q_inputs[i % 2] = q_inputs[i % 2].contiguous()
1070
1071
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1072
1073
1074
1075
1076
1077
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1078
                                    ).contiguous()
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
1091
                                fp8_meta_kwargs = {}
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
1102
1103
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
1104
1105
1106
1107
1108
1109
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q // 2,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1110
1111
1112
1113
                                    q_part,
                                    k_part,
                                    v_part,
                                    qkv_dtype,
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=(
                                        None
                                        if cu_seqlens_q_padded is None
                                        else cu_seqlens_q_padded // 2
                                    ),
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
1128
                                )
1129
1130
1131
1132
1133
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1134
                            else:
1135
                                fa_forward_args_thd = []
1136
                                if qkv_format == "thd":
1137
1138
1139
1140
1141
1142
                                    fa_forward_args_thd = [
                                        cu_seqlens_q_per_step[i],
                                        cu_seqlens_kv_per_step[i],
                                        max_seqlen_q // 2,
                                        max_seqlen_kv,
                                    ]
1143
1144
                                if fa_utils.use_v3 or (
                                    fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
1145
                                ):
1146
                                    fa_forward_kwargs["window_size"] = (-1, -1)
1147
                                elif fa_utils.v2_7_0_plus:
1148
1149
                                    fa_forward_kwargs["window_size_left"] = -1
                                    fa_forward_kwargs["window_size_right"] = -1
1150
                                fa_outputs = flash_attn_fwd(
1151
                                    q_inputs[i % 2],
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
1163
                                    causal=False,
1164
                                    **fa_forward_kwargs,
1165
                                )
1166
                                if not fa_utils.v2_7_0_plus:
1167
1168
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
1169
                                    if not fa_utils.use_v3:
1170
1171
1172
1173
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
1174
                                    if not fa_utils.use_v3:
1175
                                        rng_states[i] = fa_outputs[3]
1176
                    else:
1177
                        if pad_between_seqs:
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
1189
1190
                        elif qkv_format == "thd":
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
1191
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1192
1193
1194
                        else:
                            cu_seqlens_q_per_step[i] = cu_seqlens_q
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv
1195
                        if use_fused_attention:
1196
1197
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
1198
1199
1200
1201
1202
1203
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
1204
                                ).contiguous()
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216

                            q_part = q
                            k_part = (
                                kv_inputs[i % 2][..., 0, :, :]
                                if qkv_format in ["bshd", "sbhd"]
                                else kv_inputs[i % 2][0]
                            )
                            v_part = (
                                kv_inputs[i % 2][..., 1, :, :]
                                if qkv_format in ["bshd", "sbhd"]
                                else kv_inputs[i % 2][1]
                            )
1217
                            fp8_meta_kwargs = {}
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
                            if fp8:
                                q_part = QKV_quantizer.create_tensor_from_data(
                                    q_part, fake_dtype=qkv_dtype, internal=True
                                )
                                k_part = QKV_quantizer.create_tensor_from_data(
                                    k_part, fake_dtype=qkv_dtype, internal=True
                                )
                                v_part = QKV_quantizer.create_tensor_from_data(
                                    v_part, fake_dtype=qkv_dtype, internal=True
                                )
1228
1229
                                fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
1230
1231
1232
1233
1234
1235
                            out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                is_training,
                                max_seqlen_q,
                                max_seqlen_kv,
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
1236
1237
1238
1239
                                q_part,
                                k_part,
                                v_part,
                                qkv_dtype,
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
                                fused_attn_backend,
                                attn_scale=softmax_scale,
                                dropout=dropout_p,
                                qkv_layout=qkv_layout,
                                attn_mask_type=attn_mask_type,
                                attn_bias_type=attn_bias_type,
                                attn_bias=attn_bias_inputs[i % 2],
                                cu_seqlens_q_padded=cu_seqlens_q_padded,
                                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                **fp8_meta_kwargs,
1250
                            )
1251
1252
1253
1254
1255
                            if fp8:
                                softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                            else:
                                softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                attn_biases[i] = rest[0] if len(rest) > 0 else None
1256
                        else:
1257
1258
1259
1260
1261
1262
1263
1264
                            fa_forward_args_thd = []
                            if qkv_format == "thd":
                                fa_forward_args_thd = [
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                ]
1265
                            fa_outputs = flash_attn_fwd(
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                *fa_forward_args_thd,
1278
                                causal=False,
1279
                                **fa_forward_kwargs,
1280
                            )
1281
                            if not fa_utils.v2_7_0_plus:
1282
1283
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
1284
                                if not fa_utils.use_v3:
1285
1286
1287
1288
                                    rng_states[i] = fa_outputs[7]
                            else:
                                out_per_step[i] = fa_outputs[0]
                                softmax_lse_per_step[i] = fa_outputs[1]
1289
                                if not fa_utils.use_v3:
1290
                                    rng_states[i] = fa_outputs[3]
1291
1292
1293
1294

            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
1295
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
1296

1297
                if use_fused_attention:
1298
1299
                    # [b, np, sq, 1] -> [b, np, sq] or
                    # [t, np, 1] -> [t, np]
1300
                    softmax_lse_per_step[i - 1].squeeze_(-1)
1301
1302
1303
1304
                    if softmax_lse_in_packed_format:
                        softmax_lse_per_step[i - 1] = (
                            softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
                        )
1305

1306
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
1307
                    if fp8:
1308
                        out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
1309
1310
                    if i == 1:
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
1311
1312
                        if qkv_format == "thd":
                            out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
1313
1314
1315
1316
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
1317
                    else:
1318
                        if qkv_format == "thd":
1319
                            tex.thd_second_half_lse_correction(
1320
1321
1322
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
1323
                                softmax_lse_in_packed_format,
1324
                            )
1325
                        else:
1326
1327
1328
                            flash_attn_fwd_second_half_softmax_lse_correction(
                                softmax_lse.view(*softmax_lse.shape[:-1], 2, -1),
                                softmax_lse_per_step[i - 1],
1329
                            )
1330
1331

                if i < cp_size:
1332
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
1333
1334
1335

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

1336
1337
1338
1339
        second_half_lse_seqlen = None
        if causal and rank < (cp_size - 1):
            second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1]

1340
1341
        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
1342
            if i <= rank or not causal:
1343
                if qkv_format in ["bshd", "sbhd"]:
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
                    if i == 0:
                        out = flash_attn_fwd_out_correction_init(
                            out_per_step[0],
                            softmax_lse,
                            softmax_lse_per_step[0],
                            seq_dim,
                        )
                        out = out.view(q.shape)
                    else:
                        flash_attn_fwd_out_correction(
                            out.view(*out_per_step[i].shape),
                            out_per_step[i],
                            softmax_lse,
                            softmax_lse_per_step[i],
                            seq_dim,
                        )
1360
                elif qkv_format == "thd":
1361
1362
1363
1364
1365
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
1366
                        cu_seqlens_q_padded,
1367
                        False,
1368
                        softmax_lse_in_packed_format,
1369
                    )
1370
            else:
1371
                if qkv_format in ["bshd", "sbhd"]:
1372
1373
                    flash_attn_fwd_second_half_out_correction(
                        out,
1374
                        out_per_step[i],
1375
                        softmax_lse,
1376
                        softmax_lse_per_step[i],
1377
                        seq_dim,
1378
                    )
1379
                elif qkv_format == "thd":
1380
1381
1382
1383
1384
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
1385
                        cu_seqlens_q_padded,
1386
                        True,
1387
                        softmax_lse_in_packed_format,
1388
                    )
1389
1390

        kv = p2p_comm_buffers[-1]
1391
1392
1393
1394
1395
1396
1397
1398
        if qkv_format == "bshd":
            out = out.view(out.shape[0], -1, *out.shape[-2:])
            ctx.batch_size = out.shape[0]
        elif qkv_format == "sbhd":
            out = out.view(-1, *out.shape[-3:])
            ctx.batch_size = out.shape[1]

        if cp_size_a2a > 1:
1399
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device)
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
            out = flash_attn_a2a_communicate(
                out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
            )
            if use_fused_attention:
                if qkv_format == "bshd":
                    # [b*s, np, hn] -> [b, s, np, hn]
                    out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                elif qkv_format == "sbhd":
                    # [s*b, np, hn] -> [s, b, np, hn]
                    out = out.view(-1, ctx.batch_size, *out.shape[-2:])
        elif not use_fused_attention:
1411
            out = out.view(-1, *out.shape[-2:])
1412

1413
1414
1415
1416
1417
        if fp8 and use_fused_attention:
            amax_cp_fwd = amax_per_step.amax(dim=1)
            S_quantizer.amax = amax_cp_fwd[0]
            O_CP_quantizer.amax = amax_cp_fwd[1]

1418
        out_fp8 = None
1419
        out_f16 = out.to(qkv_dtype)
1420

1421
        if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
1422
1423
1424
            out_fp8 = O_quantizer(out_f16)  # final result

        out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16
1425
1426

        if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
1427
            q_save, kv_save, out_save = q, kv, out_fp8._data
1428
        elif fp8 and is_input_fp8:
1429
            q_save, kv_save, out_save = q, kv, out_f16
1430
        else:
1431
            q_f16 = q_f16.view(q.shape)
1432
1433
            q_save, kv_save, out_save = q_f16, kv, out_f16

1434
        tensors_to_save, tensor_objects = prepare_for_saving(
1435
1436
1437
            q_save,
            kv_save,
            out_save,
1438
            softmax_lse,
1439
1440
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
1441
1442
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
1443
1444
            *rng_states,
            *attn_biases,
1445
        )
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects

        ctx.qkv_dtype = qkv_dtype
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.O_CP_quantizer = O_CP_quantizer
        ctx.S_quantizer = S_quantizer
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer

1459
1460
1461
        ctx.cp_group_a2a = cp_group_a2a
        ctx.cp_size_a2a = cp_size_a2a
        ctx.rank_a2a = rank_a2a
1462
1463
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
1464
        ctx.cp_stream = cp_stream
1465
1466
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
1467
        ctx.max_seqlen_kv = max_seqlen_kv
1468
        ctx.softmax_scale = softmax_scale
1469
        ctx.qkv_format = qkv_format
1470
        ctx.attn_mask_type = attn_mask_type
1471
1472
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
1473
        ctx.deterministic = deterministic
1474
        ctx.use_fused_attention = use_fused_attention
1475
        ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format
1476
        ctx.second_half_lse_seqlen = second_half_lse_seqlen
1477
1478
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
1479
1480
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
1481
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
1482

1483
        return out_ret
1484
1485
1486

    @staticmethod
    def backward(ctx, dout):
1487
        # pylint: disable=missing-function-docstring
1488
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
1489
1490
1491
        cp_size_a2a = ctx.cp_size_a2a
        rank_a2a = ctx.rank_a2a

1492
1493
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
1494
1495
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
1496
1497
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1498
        q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
1499
            restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
1500
1501
1502
1503
1504
        )
        cu_seqlens_q_per_step = other_tensors[:cp_size]
        cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
        rng_states = other_tensors[cp_size * 2 : cp_size * 3]
        attn_biases = other_tensors[cp_size * 3 : cp_size * 4]
1505

1506
1507
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
1508
1509

        seq_dim = None
1510
        if ctx.qkv_format in ["bshd", "sbhd"]:
1511
            seq_dim = ctx.qkv_format.index("s")
1512
1513
1514
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
1515

1516
        if attn_biases[0] is not None:
1517
1518
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
1519
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
1520
1521
1522
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
1523
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
1524
1525
1526
            )
        else:
            attn_dbias = None
1527
            attn_dbias_ = None
1528

1529
1530
        softmax_lse_ = None
        if causal and ctx.second_half_lse_seqlen is not None:
1531
            if ctx.qkv_format == "thd":
1532
                softmax_lse_ = tex.thd_read_second_half_lse(
1533
1534
1535
1536
                    softmax_lse,
                    cu_seqlens_q_padded,
                    ctx.softmax_lse_in_packed_format,
                    ctx.second_half_lse_seqlen,
1537
                )
1538
1539
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
1540
                softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)
1541
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
1542
1543
1544
1545
1546
1547
            if ctx.use_fused_attention:
                if ctx.softmax_lse_in_packed_format:
                    softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous()
                # [b, np, sq//2] -> [b, np, sq//2, 1] or
                # [t//2, np] -> [t//2, np, 1]
                softmax_lse_.unsqueeze_(-1)
1548
        if ctx.use_fused_attention:
1549
1550
1551
1552
            if ctx.softmax_lse_in_packed_format:
                softmax_lse = softmax_lse.transpose(0, 1).contiguous()
            # [b, np, sq] -> [b, np, sq, 1] or
            # [t, np] -> [t, np, 1]
1553
            softmax_lse.unsqueeze_(-1)
1554
            dout = dout.contiguous()
1555

1556
        dq = None
1557
        dout_dtype = dout.dtype
1558
1559
        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
1560
1561
1562
        amax_per_step = None
        dP_quantizer_per_step = [None for _ in range(cp_size)]
        dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)]
1563
1564
1565
        if ctx.fp8:
            if ctx.use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
1566

1567
1568
1569
1570
1571
1572
1573
1574
1575
                dqkv_fp8_torch_dtype = get_fp8_torch_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )
                dq_fp8 = torch.empty(
                    (cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device
                )
                dkv_fp8 = torch.empty(
                    (cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device
                )
1576
                dkv_fp8_ = torch.empty_like(dkv_fp8)
1577
                if ctx.is_output_fp8:
1578
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
1579
                    ctx.dO_quantizer = dout._quantizer
1580
                else:
1581
                    dout = ctx.dO_quantizer(dout)
1582
1583
                fused_attn_dqkv_dtype = dout._fp8_dtype
                dout = dout._data
1584
1585
                p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
                fp8_meta_kwargs = {}
1586
                fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
1587
1588
1589
1590
1591
1592
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
                for i in range(cp_size):
                    dP_quantizer_per_step[i] = ctx.dP_quantizer.copy()
                    dP_quantizer_per_step[i].amax = amax_per_step[0][i]
                    dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy()
                    dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
1593
1594
1595
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
            if ctx.fp8_meta is not None:
                if ctx.is_input_fp8:
                    q = ctx.QKV_quantizer.create_tensor_from_data(
                        q, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    kv = ctx.QKV_quantizer.create_tensor_from_data(
                        kv, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    q = q.dequantize(dtype=ctx.qkv_dtype)
                    kv = kv.dequantize(dtype=ctx.qkv_dtype)
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    if cp_size_a2a == 1:
                        dout = dout.dequantize(dtype=dout_dtype)
                    else:
                        ctx.dO_quantizer = dout._quantizer
                        dout = dout._data
1613
1614
1615
1616
1617
1618
1619
1620
            dq = torch.empty_like(q)
            p2p_comm_buffers = [
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            ]
            p2p_comm_buffers[0][0].copy_(kv)
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
1621
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
1622
1623
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

1624
1625
1626
1627
        if cp_size_a2a > 1:
            if not ctx.use_fused_attention:
                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                dout = dout.view(*out.shape)
1628
1629
1630
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(
                cp_size_a2a, out.device
            )
1631
1632
1633
1634
1635
1636
1637
1638
1639
            out, dout = flash_attn_a2a_communicate(
                [out, dout],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                True,
            )
1640
            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
1641
1642
1643
1644
                dout = ctx.dO_quantizer.create_tensor_from_data(
                    dout, fake_dtype=dout_dtype, internal=True
                )
                dout = dout.dequantize(dtype=dout_dtype)
1645

1646
1647
1648
1649
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        send_recv_reqs = []

1650
        flash_attn_bwd = None
1651
1652
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
1653
            if fa_utils.use_v3:
1654
1655
1656
1657
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
1658
1659
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
1660
1661
1662
1663
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
1664
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
1665
                if fa_utils.v2_4_plus:
1666
                    fa_backward_kwargs["alibi_slopes"] = None
1667
                if fa_utils.v2_4_1_plus:
1668
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
1669
                if fa_utils.v2_6_0_plus:
1670
                    fa_backward_kwargs["softcap"] = 0.0
1671

1672
1673
1674
1675
1676
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

1677
1678
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
            if ctx.fp8:
                if i < cp_size - 1:
                    send_recv_reqs = flash_attn_p2p_communicate(
                        rank,
                        send_tensor[0],
                        send_dst,
                        recv_tensor[0],
                        recv_src,
                        ctx.cp_group,
                        batch_p2p_comm,
                    )
                else:
                    dkv_a2a_req = torch.distributed.all_to_all_single(
                        dkv_fp8,
                        dkv_fp8_,
                        group=ctx.cp_group,
                        async_op=True,
                    )
                    send_recv_reqs = [dkv_a2a_req]
            else:
                if i == 0:
                    send_tensor = send_tensor[0]
                    recv_tensor = recv_tensor[0]
                if i == (cp_size - 1):
                    send_tensor = send_tensor[1]
                    recv_tensor = recv_tensor[1]
                send_recv_reqs = flash_attn_p2p_communicate(
                    rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
                )
1708

1709
            kv = p2p_comm_buffers[i % 2][0]
1710
1711
            q_, kv_, out_, dout_ = None, None, None, None
            dq_, dk_, dv_ = None, None, None
1712
            # In reversed order of fwd
1713
            if causal:
1714
                if i == (cp_size - 1):
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        q_, kv_, out_, dout_ = q, kv, out, dout
1729
                    if ctx.use_fused_attention:
1730
1731
1732
1733
1734
1735
1736
1737
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1738
                        if attn_dbias is not None:
1739
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
1760
                                dout_part, fake_dtype=dout_dtype, internal=True
1761
                            )
1762
1763
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
1764
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1765
                            ctx.max_seqlen_q,
1766
1767
1768
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
1769
1770
1771
1772
1773
1774
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
1775
                            fused_attn_dqkv_dtype,
1776
                            aux_ctx_tensors,
1777
                            fused_attn_backend,
1778
1779
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1780
1781
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1782
                            qkv_layout=qkv_layout,
1783
                            attn_mask_type=ctx.attn_mask_type,
1784
                            attn_bias_type=ctx.attn_bias_type,
1785
1786
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
1787
                        )
1788
1789
1790
1791
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
1792
                    else:
1793
                        dq_ = torch.empty_like(q_)
1794
                        dkv_ = torch.empty_like(kv_)
1795
1796
1797
1798
1799
1800
1801
1802
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q,
                                ctx.max_seqlen_kv,
                            ]
1803
                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
1804
                            fa_backward_kwargs["window_size"] = (-1, 0)
1805
                        elif fa_utils.v2_7_0_plus:
1806
1807
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = 0
1808
                        if not fa_utils.use_v3:
1809
1810
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
1811
1812
                            dout_,
                            q_,
1813
1814
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
1815
1816
1817
                            out_,
                            softmax_lse,
                            dq_,
1818
1819
1820
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
1821
1822
                            causal=True,
                            **fa_backward_kwargs,
1823
                        )
1824
                elif i >= (cp_size - rank - 1):
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                        kv_ = kv[:, 0]
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                        kv_ = kv[0]
                    elif ctx.qkv_format == "thd":
                        q_, out_, dout_ = q, out, dout
                        # [2, t, np, hn] -> [2, t/2, np, hn]
                        kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
1841
                    if ctx.use_fused_attention:
1842
                        kv_ = kv_.contiguous()
1843
1844
1845
1846
1847
1848
1849
1850
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1851
                        if attn_dbias is not None:
1852
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
1873
                                dout_part, fake_dtype=dout_dtype, internal=True
1874
                            )
1875
1876
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
1877
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1878
                            ctx.max_seqlen_q,
1879
1880
1881
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
1882
1883
1884
1885
1886
1887
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
1888
                            fused_attn_dqkv_dtype,
1889
                            aux_ctx_tensors,
1890
                            fused_attn_backend,
1891
1892
1893
1894
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
1895
1896
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1897
                            qkv_layout=qkv_layout,
1898
                            attn_mask_type="padding" if padding else "no_mask",
1899
                            attn_bias_type=ctx.attn_bias_type,
1900
1901
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
1902
                        )
1903
1904
1905
1906
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
1907
                    else:
1908
                        dq_ = torch.empty_like(q_)
1909
                        dkv_ = torch.empty_like(kv_)
1910
1911
1912
1913
1914
1915
1916
1917
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q,
                                ctx.max_seqlen_kv // 2,
                            ]
1918
                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
1919
                            fa_backward_kwargs["window_size"] = (-1, -1)
1920
                        if fa_utils.v2_7_0_plus:
1921
1922
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
1923
                        if not fa_utils.use_v3:
1924
1925
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
1926
1927
                            dout_,
                            q_,
1928
1929
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
1930
1931
1932
                            out_,
                            softmax_lse,
                            dq_,
1933
1934
1935
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
1936
1937
                            causal=False,
                            **fa_backward_kwargs,
1938
1939
                        )
                else:
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_, out_, dout_ = q[1], out[1], dout[1]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        # [t, np, hn] -> [t/2, np, hn]
                        q_, out_, dout_ = [
                            tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1)
                            for x in [q, out, dout]
                        ]
                        kv_ = kv
1957
                    if ctx.use_fused_attention:
1958
                        q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]]
1959
1960
1961
1962
1963
1964
1965
1966
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse_,
                                softmax_lse_,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
1967
                        if attn_dbias is not None:
1968
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989

                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
1990
                                dout_part, fake_dtype=dout_dtype, internal=True
1991
                            )
1992
1993
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
1994
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1995
                            ctx.max_seqlen_q // 2,
1996
1997
1998
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
1999
2000
2001
2002
2003
2004
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
                            ctx.qkv_dtype,
2005
                            fused_attn_dqkv_dtype,
2006
                            aux_ctx_tensors,
2007
                            fused_attn_backend,
2008
2009
2010
2011
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2012
2013
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2014
                            qkv_layout=qkv_layout,
2015
                            attn_mask_type="padding" if padding else "no_mask",
2016
                            attn_bias_type=ctx.attn_bias_type,
2017
2018
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2019
                        )
2020
2021
2022
2023
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
2024
                    else:
2025
                        dq_ = torch.empty_like(q_)
2026
                        dkv_ = torch.empty_like(kv_)
2027
                        fa_backward_args_thd = []
2028
                        if ctx.qkv_format == "thd":
2029
2030
2031
2032
2033
2034
                            fa_backward_args_thd = [
                                cu_seqlens_q_per_step[cp_size - i - 1],
                                cu_seqlens_kv_per_step[cp_size - i - 1],
                                ctx.max_seqlen_q // 2,
                                ctx.max_seqlen_kv,
                            ]
2035
                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
2036
                            fa_backward_kwargs["window_size"] = (-1, -1)
2037
                        elif fa_utils.v2_7_0_plus:
2038
2039
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
2040
                        if not fa_utils.use_v3:
2041
2042
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
2043
2044
                            dout_,
                            q_,
2045
2046
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2047
2048
2049
                            out_,
                            softmax_lse_,
                            dq_,
2050
2051
2052
                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
2053
2054
                            causal=False,
                            **fa_backward_kwargs,
2055
2056
2057
                        )
            else:
                if ctx.use_fused_attention:
2058
2059
2060
2061
                    if ctx.fp8:
                        aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
                    else:
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2062
                    if attn_dbias is not None:
2063
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2064
2065
2066
2067
2068
2069
2070
2071
                    q_part = q
                    k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
                    v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
                    out_part = out
                    dout_part = dout

                    if ctx.fp8:
                        q_part = ctx.QKV_quantizer.create_tensor_from_data(
2072
                            q_part, fake_dtype=ctx.qkv_dtype, internal=True
2073
2074
                        )
                        k_part = ctx.QKV_quantizer.create_tensor_from_data(
2075
                            k_part, fake_dtype=ctx.qkv_dtype, internal=True
2076
2077
                        )
                        v_part = ctx.QKV_quantizer.create_tensor_from_data(
2078
                            v_part, fake_dtype=ctx.qkv_dtype, internal=True
2079
2080
                        )
                        out_part = ctx.O_quantizer.create_tensor_from_data(
2081
                            out_part, fake_dtype=ctx.qkv_dtype, internal=True
2082
2083
                        )
                        dout_part = ctx.dO_quantizer.create_tensor_from_data(
2084
                            dout_part, fake_dtype=dout_dtype, internal=True
2085
                        )
2086
2087
                        fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                        fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
2088
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2089
                        ctx.max_seqlen_q,
2090
2091
2092
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
2093
2094
2095
2096
2097
2098
                        q_part,
                        k_part,
                        v_part,
                        out_part,
                        dout_part,
                        ctx.qkv_dtype,
2099
                        fused_attn_dqkv_dtype,
2100
                        aux_ctx_tensors,
2101
                        fused_attn_backend,
2102
2103
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2104
2105
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
2106
                        qkv_layout=qkv_layout,
2107
                        attn_mask_type=ctx.attn_mask_type,
2108
                        attn_bias_type=ctx.attn_bias_type,
2109
2110
                        deterministic=ctx.deterministic,
                        **fp8_meta_kwargs,
2111
                    )
2112
2113
2114
2115
2116
2117

                    if ctx.fp8:
                        dq_ = dq_._data
                        dk_ = dk_._data
                        dv_ = dv_._data

2118
                else:
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
                    dq_ = torch.empty_like(q)
                    dkv_ = torch.empty_like(kv)
                    fa_backward_args_thd = []
                    if ctx.qkv_format == "thd":
                        fa_backward_args_thd = [
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_kv,
                        ]
2129
                    if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
2130
                        fa_backward_kwargs["window_size"] = (-1, -1)
2131
                    elif fa_utils.v2_7_0_plus:
2132
2133
                        fa_backward_kwargs["window_size_left"] = -1
                        fa_backward_kwargs["window_size_right"] = -1
2134
                    if not fa_utils.use_v3:
2135
2136
                        fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                    flash_attn_bwd(
2137
2138
2139
2140
2141
                        dout,
                        q,
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
                        out,
2142
2143
                        softmax_lse,
                        dq_,
2144
2145
2146
                        dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                        dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                        *fa_backward_args_thd,
2147
2148
                        causal=False,
                        **fa_backward_kwargs,
2149
2150
                    )

2151
2152
            if ctx.fp8:
                dq = dq_fp8[(rank + i + 1) % cp_size]
2153
2154
2155
            if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1):
                # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or
                # [sq, b, np, hn] -> [2, sq//2, b, np, hn]
2156
                dq_ = dq_.view(*dq.shape)
2157

2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
            if ctx.fp8:
                if i >= (cp_size - rank - 1) or not causal:
                    dq.copy_(dq_)
                else:
                    if ctx.qkv_format == "bshd":
                        dq[:, 0, ...].fill_(0)
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[0].fill_(0)
                        dq[1].copy_(dq_)
            elif causal:
2169
                if i > (cp_size - rank - 1):
2170
                    dq.add_(dq_)
2171
2172
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
2173
2174
                        dq.copy_(dq_)
                    else:
2175
2176
2177
2178
2179
2180
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
2181
                        elif ctx.qkv_format == "thd":
2182
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
2183
                elif i > 0:
2184
2185
2186
2187
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
2188
                    elif ctx.qkv_format == "thd":
2189
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
2190
                else:
2191
2192
2193
2194
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
2195
                    elif ctx.qkv_format == "thd":
2196
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
2197
2198
2199
2200
2201
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
2202

2203
            if attn_dbias is not None:
2204
                idx = (rank + i + 1) % cp_size
2205
                if i == (cp_size - 1) or not causal:
2206
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
2207
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2208
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
2209
2210
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
2211
2212
2213
2214
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
2215
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2216
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
2217
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
2218

2219
2220
2221
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
2222

2223
2224
2225
2226
2227
2228
2229
            if ctx.fp8:
                if i < cp_size - 1:
                    dkv = dkv_fp8_[(rank + i + 1) % cp_size]
                else:
                    dkv = dkv_fp8[(rank + i + 1) % cp_size]
            else:
                dkv = p2p_comm_buffers[(i + 1) % 2][1]
2230
            if ctx.use_fused_attention:
2231
                if ctx.qkv_format in ["bshd", "sbhd"]:
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
                    dkv_ = _combine_tensors([dk_, dv_], -2)
                elif ctx.qkv_format == "thd":
                    dkv_ = torch.cat(
                        (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
                    )  # pylint: disable=used-before-assignment
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
                dkv_ = dkv_.movedim(-3, 0)
                if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
                    # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(*dkv.shape)
2246

2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
            if ctx.fp8:
                if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
                    if ctx.qkv_format == "bshd":
                        dkv[:, :, 0, ...].copy_(dkv_)
                        dkv[:, :, 1, ...].fill_(0)
                    elif ctx.qkv_format == "sbhd":
                        dkv[:, 0, ...].copy_(dkv_)
                        dkv[:, 1, ...].fill_(0)
                else:
                    dkv.copy_(dkv_)
            elif causal:
2258
                if i == (cp_size - 1):
2259
                    if rank == 0:
2260
2261
2262
2263
2264
2265
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
2266
                        elif ctx.qkv_format == "thd":
2267
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
2268
2269
                    else:
                        dkv.add_(dkv_)
2270
2271
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
2272
2273
2274
2275
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
2276
                        elif ctx.qkv_format == "thd":
2277
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
2278
                    else:
2279
2280
2281
2282
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
2283
                        elif ctx.qkv_format == "thd":
2284
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
2285
2286
2287
2288
2289
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
2290
2291
2292
2293
2294
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

2295
        if ctx.fp8 and ctx.use_fused_attention:
2296
2297
2298
            amax_cp_bwd = amax_per_step.amax(dim=1)
            ctx.dP_quantizer.amax = amax_cp_bwd[0]
            ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1]
2299
2300
2301
2302
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
                # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
                dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
2303
2304
2305
2306
2307
2308
2309
            dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                dq_fp8, fake_dtype=torch.float32, internal=True
            )
            dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                dkv_fp8, fake_dtype=torch.float32, internal=True
            )
            dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]]
2310
2311
            dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]

2312
        if causal:
2313
2314
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
2315
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
2316
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
2317
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
2318
2319
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
2320
                dq = dq.view(-1, *dq.shape[-3:])
2321
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
2322
2323
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

2324
2325
2326
        if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
            dq[cu_seqlens_q_padded[-1] :].fill_(0)
            dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
2327

2328
        if ctx.fp8 and ctx.is_input_fp8:
2329
2330
            assert torch.uint8 not in [dq.dtype, dkv.dtype]
            dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]]
2331
2332
2333
        dk, dv = dkv[0], dkv[1]

        if cp_size_a2a > 1:
2334
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device)
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
            dq, dk, dv = flash_attn_a2a_communicate(
                [dq, dk, dv],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                False,
            )
            if ctx.qkv_format == "bshd":
                dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
            elif ctx.qkv_format == "sbhd":
                dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

2349
2350
2351
        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
2352
2353
        # converting torch.uint8 to float8tensor
        if ctx.fp8 and ctx.is_input_fp8:
2354
2355
2356
            dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype)
            dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype)
            dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
2357
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
2358

2359
2360
2361
        return (
            None,
            dq,
2362
2363
            dk,
            dv,
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
2375
            attn_dbias,
2376
2377
2378
2379
2380
            None,
            None,
            None,
            None,
            None,
2381
2382
            None,
            None,
2383
            None,
2384
            None,
2385
        )
2386
2387


2388
2389
def get_kv_seq_info_after_all_gather(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
    """Compute KV sequence index range and update window size after all-gather.

    The all-gathered KV sequence is treated as having total length
    ``max_seqlen_kv * cp_size * 2``. The local query chunk ends (exclusive) at
    ``(local_chunk_id + 1) * max_seqlen_kv`` in that sequence.

    Each side of the window works the same way. A value of ``-1`` means that
    side is unbounded, so the slice extends to the sequence boundary and the
    returned size stays ``-1``. Otherwise the slice is clipped to what the
    window can reach, and the window size is re-expressed relative to the
    clipped slice's end. NOTE(review): this matches bottom-right aligned masks,
    which the caller appears to use; confirm for other alignments.

    Args:
        local_chunk_id: index of the local chunk within the gathered sequence.
        cp_size: context-parallel world size.
        max_seqlen_q: query length of the local chunk.
        max_seqlen_kv: length of one KV chunk.
        window_size: ``(left, right)`` sliding window, or ``None``. ``None``
            becomes ``(-1, 0)`` when ``causal`` and ``(-1, -1)`` otherwise.
        causal: whether the mask is causal; only used when ``window_size`` is None.

    Returns:
        ``((seq_start_idx, seq_end_idx), (window_size_left, window_size_right))``.
        The first pair is the KV range to slice. The second pair is the window
        size to use on that slice.
    """
    # Fill in the default window implied by the mask type.
    if window_size is None:
        window_size = (-1, 0) if causal else (-1, -1)
    requested_left, requested_right = window_size[0], window_size[1]

    chunk_end = (local_chunk_id + 1) * max_seqlen_kv
    total_len = max_seqlen_kv * cp_size * 2

    # Right edge: clip at the end of the gathered sequence. The remaining
    # look-ahead is measured from the clipped end.
    if requested_right == -1:
        kv_end, new_right = total_len, -1
    else:
        kv_end = min(total_len, chunk_end + requested_right)
        new_right = chunk_end + requested_right - kv_end

    # Left edge: the first query of the chunk (at chunk_end - max_seqlen_q)
    # can look back `requested_left` tokens. The window grows by however far
    # kv_end sits past chunk_end, because it is now measured from kv_end.
    if requested_left == -1:
        kv_start, new_left = 0, -1
    else:
        kv_start = max(0, chunk_end - max_seqlen_q - requested_left)
        new_left = requested_left + kv_end - chunk_end

    return (kv_start, kv_end), (new_left, new_right)
2413
2414
2415
2416


class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
    Refer section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.

    Each rank holds two query chunks of the full sequence (chunk ids ``rank`` and
    ``2 * cp_size - rank - 1``). K and V are all-gathered over ``cp_group``. Each
    local query chunk then attends to the KV range returned by
    ``get_kv_seq_info_after_all_gather``. The two per-chunk steps are issued on the
    current CUDA stream and on ``cp_stream`` respectively.

    Only ``bshd``/``sbhd`` layouts, non-padding masks and ``no_bias`` are supported.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        cp_group,
        cp_stream,
    ):
        """Run the two per-chunk attention steps over all-gathered K/V.

        ``max_seqlen_q``/``max_seqlen_kv`` are full-sequence lengths; they are
        divided by ``2 * cp_size`` to obtain per-chunk lengths. Returns the
        attention output:
        - FusedAttention: reshaped back to the input layout (``[b, s, np, hn]`` or
          ``[s, b, np, hn]``).
        - FlashAttention: flattened to ``[-1, np, hn]``.
        """
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        qkv_dtype = q.dtype

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        # KV ranges are truncated per step, so the causal diagonal must be aligned
        # to the bottom-right corner of each (q chunk, KV range) block.
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            use_fused_attention or fa_utils.v2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if fa_utils.use_v3:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
            else:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if fa_utils.v2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if fa_utils.v2_5_7_plus and qkv_format == "thd":
                    fa_forward_kwargs["block_table"] = None
                if fa_utils.v2_6_0_plus:
                    fa_forward_kwargs["softcap"] = 0.0

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        # full-sequence lengths -> per-chunk lengths
        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        if use_fused_attention or qkv_format == "thd":
            cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        if cu_seqlens_q_padded is not None and qkv_format == "thd":
            cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
        else:
            cu_seqlens_q_padded = None

        # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
        q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
        # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
        k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        cp_stream.wait_stream(torch.cuda.current_stream())

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        kv_seq_range_per_step = [None, None]
        window_size_per_step = [None, None]
        cu_seqlens_kv_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]
        out = torch.empty_like(q)

        # Software pipeline: iteration i launches step i and copies out step i-1.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    kv_seq_range_per_step[i], window_size_per_step[i] = (
                        get_kv_seq_info_after_all_gather(
                            local_seq_chunk_ids[i],
                            cp_size,
                            max_seqlen_q,
                            max_seqlen_kv,
                            window_size,
                            causal,
                        )
                    )
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv_ = seq_end_idx - seq_start_idx
                    if use_fused_attention or qkv_format == "thd":
                        cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens(
                            k.shape[1], max_seqlen_kv_, k.device
                        )
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv_,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            qkv_dtype,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            window_size=window_size_per_step[i],
                        )
                    else:
                        fa_forward_args_thd = []
                        if qkv_format == "thd":
                            fa_forward_args_thd = [
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                max_seqlen_q,
                                max_seqlen_kv_,
                            ]
                        # FA3 and FA2 in [2.3, 2.7) take a `window_size` tuple;
                        # FA2 >= 2.7 takes separate left/right kwargs.
                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
                            fa_forward_kwargs["window_size"] = window_size_per_step[i]
                        elif fa_utils.v2_7_0_plus:
                            fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
                            fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1]
                        fa_outputs = flash_attn_fwd(
                            q_,
                            k_,
                            v_,
                            *fa_forward_args_thd,
                            causal=causal,
                            **fa_forward_kwargs,
                        )
                        if not fa_utils.v2_7_0_plus:
                            out_per_step[i] = fa_outputs[4]
                            softmax_lse_per_step[i] = fa_outputs[5]
                            if not fa_utils.use_v3:
                                rng_states[i] = fa_outputs[7]
                        else:
                            out_per_step[i] = fa_outputs[0]
                            softmax_lse_per_step[i] = fa_outputs[1]
                            if not fa_utils.use_v3:
                                rng_states[i] = fa_outputs[3]

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1])
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1])

        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
        else:
            out = out.view(-1, *out.shape[-2:])

        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_q_padded,
            *cu_seqlens_kv_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )

        ctx.qkv_dtype = qkv_dtype
        ctx.kv_seq_range_per_step = kv_seq_range_per_step
        ctx.window_size_per_step = window_size_per_step
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
        return out

    @staticmethod
    def backward(ctx, dout):
        """Backward pass through the two attention steps.

        The steps are as follows:
        1. Re-gathers K/V over ``cp_group``.
        2. Runs the per-step attention backward with the KV ranges and window
           sizes saved in forward.
        3. Accumulates dK/dV into buffers covering the gathered sequence.
        4. Restores the original chunk order and reduce-scatters dK/dV back to
           each rank.

        Returns gradients for ``q``, ``k``, ``v`` and ``None`` for every other input.
        """
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

        (*saved_tensors,) = ctx.saved_tensors
        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
        cu_seqlens_kv_per_step = saved_tensors[5:7]
        out_per_step = saved_tensors[7:9]
        softmax_lse_per_step = saved_tensors[9:11]
        rng_states = saved_tensors[11:13]
        kv_seq_range_per_step = ctx.kv_seq_range_per_step
        window_size_per_step = ctx.window_size_per_step

        seq_dim = ctx.qkv_format.index("s")
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        dout = dout.view(q.shape)
        dq = torch.empty_like(q)
        # dK/dV buffers span the whole gathered sequence: [cp*s, b, np, hn]
        dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if fa_utils.use_v3:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if fa_utils.v2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if fa_utils.v2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
                if fa_utils.v2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0

        # Software pipeline: iteration i launches step i and accumulates step i-1.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i][0],
                        kv_seq_range_per_step[i][1],
                    )
                    max_seqlen_kv = seq_end_idx - seq_start_idx
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    out_ = out_per_step[i]
                    dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
                    if ctx.use_fused_attention:
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
                            max_seqlen_kv,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            ctx.qkv_dtype,
                            TE_DType[dout.dtype],
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
                            window_size=window_size_per_step[i],
                            deterministic=ctx.deterministic,
                        )
                    else:
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        fa_backward_args_thd = []
                        if ctx.qkv_format == "thd":
                            fa_backward_args_thd = [
                                cu_seqlens_q,
                                cu_seqlens_kv_per_step[i],
                                ctx.max_seqlen_q,
                                max_seqlen_kv,
                            ]
                        if not fa_utils.use_v3:
                            fa_backward_kwargs["rng_state"] = rng_states[i]
                        # Mirror the forward pass exactly. FA3 and FA2 in [2.3, 2.7)
                        # take a `window_size` tuple; FA2 >= 2.7 takes separate
                        # left/right kwargs. Otherwise FA3 backward would run
                        # without the per-step window used in forward.
                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
                            fa_backward_kwargs["window_size"] = window_size_per_step[i]
                        elif fa_utils.v2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                            fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
                        flash_attn_bwd(
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            dq_per_step[i],
                            dk_per_step[i],
                            dv_per_step[i],
                            *fa_backward_args_thd,
                            causal="causal" in ctx.attn_mask_type,
                            **fa_backward_kwargs,
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if ctx.qkv_format == "bshd":
                        dq[:, i - 1].copy_(dq_per_step[i - 1])
                    elif ctx.qkv_format == "sbhd":
                        dq[i - 1].copy_(dq_per_step[i - 1])
                    # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
                    dk_per_step[i - 1], dv_per_step[i - 1] = [
                        x.movedim(seq_dim, 0).contiguous()
                        for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
                    ]
                    # wait until dkv update of last step is done
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
                    seq_start_idx, seq_end_idx = (
                        kv_seq_range_per_step[i - 1][0],
                        kv_seq_range_per_step[i - 1][1],
                    )
                    dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
                    dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

        # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
        dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device)
        dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
        dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

        dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
        dk = dk.movedim(0, seq_dim).contiguous()
        dv = dv.movedim(0, seq_dim).contiguous()
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")

        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    """
    Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
    Refer the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        fp8,
        fp8_meta,
        cp_group,
        cp_stream,
2908
        quantizers,
2909
    ):
2910
        # pylint: disable=missing-function-docstring
2911
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
2912
2913
2914
2915
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
2916
        qkv_dtype = q.dtype
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            window_size == (-1, 0)
            or window_size == (-1, -1)
            or use_fused_attention
2927
            or fa_utils.v2_3_plus
2928
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
2929

2930
        flash_attn_fwd = None
2931
2932
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
2933
            if fa_utils.use_v3:
2934
2935
2936
2937
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
                else:
                    flash_attn_fwd = _flash_attn_fwd_v3
2938
2939
                fa_forward_kwargs["window_size"] = window_size
            else:
2940
2941
2942
2943
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
2944
2945
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
2946
                if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
2947
                    fa_forward_kwargs["window_size"] = window_size
2948
                elif fa_utils.v2_7_0_plus:
2949
2950
                    fa_forward_kwargs["window_size_left"] = window_size[0]
                    fa_forward_kwargs["window_size_right"] = window_size[1]
2951
                if fa_utils.v2_4_plus:
2952
                    fa_forward_kwargs["alibi_slopes"] = None
2953
                if fa_utils.v2_5_7_plus and qkv_format == "thd":
2954
                    fa_forward_kwargs["block_table"] = None
2955
                if fa_utils.v2_6_0_plus:
2956
                    fa_forward_kwargs["softcap"] = 0.0
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970

        assert (
            q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
        ), "The number of attention heads needs to be divisible by CP size!"

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        batch_dim = qkv_format.index("b")
        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

2971
        fused_attn_backend = None
2972
2973
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
2974
2975
2976
        is_output_fp8 = False

        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
2977
            dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
2978
2979
2980
        )
        if fp8:
            if use_fused_attention:
2981
                fused_attn_backend = FusedAttnBackend["FP8"]
2982
2983
2984
2985
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
2986
                is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
2987
                if is_input_fp8:
2988
                    QKV_quantizer = q._quantizer
2989
2990
2991
2992
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                    q_f16, k_f16, v_f16 = q, k, v
2993
                    q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]]
2994
                fp8_meta_kwargs = {}
2995
2996
                fp8_meta_kwargs["s_quantizer"] = S_quantizer
                fp8_meta_kwargs["o_quantizer"] = O_quantizer  # partial result quantizer
2997
2998
2999
3000
3001
3002
3003
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

3004
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device)
3005
3006
3007
3008
        q, k, v = flash_attn_a2a_communicate(
            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
        )

3009
        if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
3010
            q_f16, k_f16, v_f16 = q, k, v
3011
            q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]]
3012
3013
3014

        batch_size = q.shape[batch_dim]
        if use_fused_attention:
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
            q_part, k_part, v_part = q, k, v
            if fp8:
                q_part = QKV_quantizer.create_tensor_from_data(
                    q, fake_dtype=qkv_dtype, internal=True
                )
                k_part = QKV_quantizer.create_tensor_from_data(
                    k, fake_dtype=qkv_dtype, internal=True
                )
                v_part = QKV_quantizer.create_tensor_from_data(
                    v, fake_dtype=qkv_dtype, internal=True
                )
3026
3027
3028
3029
3030
3031
            out, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
3032
3033
3034
3035
                q_part,
                k_part,
                v_part,
                qkv_dtype,
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
                fused_attn_backend,
                attn_scale=softmax_scale,
                dropout=dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                attn_bias=attn_bias,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                window_size=window_size,
                **fp8_meta_kwargs,
            )
3048
3049
            if fp8:
                out = out._data
3050
        else:
3051
3052
3053
3054
3055
3056
3057
3058
            fa_forward_args_thd = []
            if qkv_format == "thd":
                fa_forward_args_thd = [
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                ]
3059
            fa_outputs = flash_attn_fwd(
3060
3061
3062
                q,
                k,
                v,
3063
                *fa_forward_args_thd,
3064
                causal=causal,
3065
                **fa_forward_kwargs,
3066
            )
3067
            if not fa_utils.v2_7_0_plus:
3068
                out, softmax_lse = fa_outputs[4], fa_outputs[5]
3069
                rng_state = fa_outputs[7] if not fa_utils.use_v3 else None
3070
3071
            else:
                out, softmax_lse = fa_outputs[0], fa_outputs[1]
3072
                rng_state = fa_outputs[3] if not fa_utils.use_v3 else None
3073
3074
            aux_ctx_tensors = [softmax_lse, rng_state]

3075
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
        out = flash_attn_a2a_communicate(
            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
        )

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b*s, np, hn] -> [b, s, np, hn]
                out = out.view(batch_size, -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [s*b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, batch_size, *out.shape[-2:])

        if fp8:
3089
            if is_output_fp8:
3090
3091
                out_fp8 = O_quantizer.create_tensor_from_data(
                    out, fake_dtype=qkv_dtype, internal=False
3092
3093
                )
                out_ret = out_fp8
3094
                out = out_fp8._data
3095
            else:
3096
                out_fp8 = O_quantizer.create_tensor_from_data(
3097
                    out, fake_dtype=qkv_dtype, internal=True
3098
                )
3099
                out_f16 = out_fp8.dequantize(dtype=qkv_dtype)
3100
3101
3102
3103
                out_ret = out_f16
        else:
            out_ret = out

3104
        if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
3105
            q_save, k_save, v_save, out_save = q, k, v, out
3106
3107
3108
3109
3110
3111
3112
3113
3114
        else:
            if is_input_fp8:
                q_save, k_save, v_save = q, k, v
            else:
                q_save, k_save, v_save = q_f16, k_f16, v_f16
            if is_output_fp8:
                out_save = out
            else:
                out_save = out_f16
3115

3116
        tensors_to_save, tensor_objects = prepare_for_saving(
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
            q_save,
            k_save,
            v_save,
            out_save,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *aux_ctx_tensors,
        )
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects

        ctx.qkv_dtype = qkv_dtype
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.S_quantizer = S_quantizer
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer

3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
3148
3149
3150
3151
3152
        ctx.batch_size = batch_size
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.window_size = window_size
        ctx.use_fused_attention = use_fused_attention
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
3153
3154
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
3155
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
3156
3157
3158
3159
        return out_ret

    @staticmethod
    def backward(ctx, dout):
        """Backward pass of context-parallel attention with all-to-all (A2A) exchanges.

        Restores the tensors saved by forward, exchanges ``out``/``dout`` across the CP
        group with ``flash_attn_a2a_communicate`` (chunk order from
        ``get_seq_chunk_ids_for_reordering_before_attn``), runs the local attention
        backward with either ``fused_attn_bwd`` or a FlashAttention backward kernel,
        then exchanges ``dq``/``dk``/``dv`` back (chunk order from
        ``get_seq_chunk_ids_for_reordering_after_attn``) and reshapes them to
        ``ctx.qkv_format``.

        Returns:
            A tuple whose entries 1-3 are the gradients for ``q``, ``k`` and ``v``;
            every other entry is ``None``.

        Raises:
            AssertionError: if FP8 backward is requested without fused attention, or
                an FP8-output run receives a ``dout`` that is not a ``Float8Tensor``.
        """
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
        cp_size = get_distributed_world_size(ctx.cp_group)

        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *aux_ctx_tensors,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
        causal = "causal" in ctx.attn_mask_type
        seq_dim = ctx.qkv_format.index("s")

        # Captured before `dout` may be replaced by its raw `_data` below.
        dout_dtype = dout.dtype
        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
        if ctx.fp8:
            if ctx.use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.dO_quantizer = dout._quantizer
                else:
                    dout = ctx.dO_quantizer(dout)
                fused_attn_dqkv_dtype = dout._fp8_dtype
                dout = dout._data
                fp8_meta_kwargs = {}
                fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
                fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer
                fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer
            else:
                # Explicit raise (not `assert False`) so the check survives `python -O`.
                raise AssertionError("FP8 is only supported with Fused Attention!")
        else:
            if ctx.fp8_meta is not None:
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.dO_quantizer = dout._quantizer
                    dout = dout._data
                if ctx.is_input_fp8:
                    # Rebuild quantized views of the saved q/k/v and dequantize them,
                    # since this branch runs the high-precision backward.
                    q = ctx.QKV_quantizer.create_tensor_from_data(
                        q, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    k = ctx.QKV_quantizer.create_tensor_from_data(
                        k, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    v = ctx.QKV_quantizer.create_tensor_from_data(
                        v, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]]
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if not ctx.use_fused_attention:
            out = out.view(ctx.batch_size, -1, *out.shape[-2:])
        dout = dout.view(*out.shape)

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device)
        out, dout = flash_attn_a2a_communicate(
            [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
        )

        if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
            # Forward produced FP8 outputs but the backward runs in high precision:
            # rewrap the exchanged raw buffers and dequantize them.
            out = ctx.O_quantizer.create_tensor_from_data(
                out, fake_dtype=ctx.qkv_dtype, internal=True
            )
            dout = ctx.dO_quantizer.create_tensor_from_data(
                dout, fake_dtype=dout_dtype, internal=True
            )
            out = out.dequantize(dtype=ctx.qkv_dtype)
            dout = dout.dequantize(dtype=dout_dtype)

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if fa_utils.use_v3:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
                else:
                    flash_attn_bwd = _flash_attn_bwd_v3
                fa_backward_kwargs["window_size"] = ctx.window_size
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                # The keyword spelling of the window size depends on the flash-attn
                # version (this branch is only reached when v3 is not in use).
                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                    fa_backward_kwargs["window_size"] = ctx.window_size
                elif fa_utils.v2_7_0_plus:
                    fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
                    fa_backward_kwargs["window_size_right"] = ctx.window_size[1]
                if fa_utils.v2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if fa_utils.v2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
                if fa_utils.v2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0

        if ctx.use_fused_attention:
            q_part = q
            k_part = k
            v_part = v
            out_part = out
            dout_part = dout

            if ctx.fp8:
                # Wrap the raw buffers back into quantized tensors for fused_attn_bwd.
                q_part = ctx.QKV_quantizer.create_tensor_from_data(
                    q_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                k_part = ctx.QKV_quantizer.create_tensor_from_data(
                    k_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                v_part = ctx.QKV_quantizer.create_tensor_from_data(
                    v_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                out_part = ctx.O_quantizer.create_tensor_from_data(
                    out_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                dout_part = ctx.dO_quantizer.create_tensor_from_data(
                    dout_part, fake_dtype=dout_dtype, internal=True
                )

            dq, dk, dv, _ = fused_attn_bwd(
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_part,
                k_part,
                v_part,
                out_part,
                dout_part,
                ctx.qkv_dtype,
                fused_attn_dqkv_dtype,
                aux_ctx_tensors,
                fused_attn_backend,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                attn_scale=ctx.softmax_scale,
                dropout=ctx.dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=ctx.attn_mask_type,
                attn_bias_type=ctx.attn_bias_type,
                window_size=ctx.window_size,
                deterministic=ctx.deterministic,
                **fp8_meta_kwargs,
            )
            if ctx.fp8:
                # Unwrap to raw buffers for the all-to-all exchange below.
                dq = dq._data
                dk = dk._data
                dv = dv._data
        else:
            softmax_lse, rng_state = aux_ctx_tensors
            dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
            fa_backward_args_thd = []
            if ctx.qkv_format == "thd":
                fa_backward_args_thd = [
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    ctx.max_seqlen_q,
                    ctx.max_seqlen_kv,
                ]
            if not fa_utils.use_v3:
                fa_backward_kwargs["rng_state"] = rng_state
            # The FlashAttention kernel writes into the preallocated dq/dk/dv buffers.
            flash_attn_bwd(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                *fa_backward_args_thd,
                causal=causal,
                **fa_backward_kwargs,
            )

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device)
        dq, dk, dv = flash_attn_a2a_communicate(
            [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
        )

        if ctx.qkv_format == "bshd":
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
        elif ctx.qkv_format == "sbhd":
            dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

        if ctx.fp8:
            dq = ctx.dQKV_quantizer.create_tensor_from_data(
                dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
            )
            dk = ctx.dQKV_quantizer.create_tensor_from_data(
                dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
            )
            dv = ctx.dQKV_quantizer.create_tensor_from_data(
                dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
            )
            if not ctx.is_input_fp8:
                # High-precision inputs receive high-precision gradients.
                dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]]
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")

        # Trailing Nones for the remaining (non-differentiable) forward inputs; the
        # count (25 entries in total) is kept unchanged from the original code.
        return (None, dq, dk, dv, *([None] * 21))


3402
def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
    window_size=None,
    fp8=False,
    fp8_meta=None,
    quantizers=None,
    pad_between_seqs=False,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates the configuration, then dispatches on ``cp_comm_type``:

    * ``"p2p"`` / ``"a2a+p2p"`` -> ``AttnFuncWithCPAndKVP2P``
    * ``"all_gather"``          -> ``AttnFuncWithCPAndKVAllGather``
    * ``"a2a"``                 -> ``AttnFuncWithCPAndQKVOA2A``

    For ``"a2a+p2p"``, ``cp_group`` must be a list of exactly two process groups.
    If either level has world size 1, the call is downgraded to ``"p2p"`` or
    ``"a2a"`` over the other group.

    Returns:
        The output of the selected autograd function's ``apply``.

    Raises:
        AssertionError: for an unsupported process-group type, ``qkv_format``,
            backend/format combination, attention bias, missing padded cu_seqlens
            with ``thd``, or sliding-window attention with a comm type other than
            ``"a2a"`` / ``"all_gather"``.
        ValueError: if, after the checks above, ``cp_comm_type`` is not one of the
            supported values.
    """

    if cp_comm_type == "a2a+p2p":
        assert isinstance(
            cp_group, list
        ), "Hierarchical CP implementation needs multi-level CP groups!"
        assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
        # Collapse a degenerate two-level hierarchy into its single effective level.
        if get_distributed_world_size(cp_group[0]) == 1:
            cp_group = cp_group[1]
            cp_comm_type = "p2p"
        elif get_distributed_world_size(cp_group[1]) == 1:
            cp_group = cp_group[0]
            cp_comm_type = "a2a"
    else:
        assert isinstance(
            cp_group, dist_group_type
        ), f"Unsupported process group for CP communication type {cp_comm_type}!"

    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    assert qkv_format != "thd" or (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_padded cannot be None with context parallelism + THD format!"

    # (-1, 0) and (-1, -1) are treated as "no sliding window"; anything else is.
    # `tuple()` makes list-valued window sizes compare equal to the tuple sentinels.
    sliding_window_attn = window_size is not None and tuple(window_size) not in (
        (-1, 0),
        (-1, -1),
    )
    assert not sliding_window_attn or cp_comm_type in [
        "a2a",
        "all_gather",
    ], "The context parallel running configs cannot support sliding window attention!"

    args = [
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ]

    if cp_comm_type in ["p2p", "a2a+p2p"]:
        args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, pad_between_seqs]
        out = AttnFuncWithCPAndKVP2P.apply(*args)
    elif cp_comm_type == "all_gather":
        # The all-gather implementation takes neither cu_seqlens_kv nor
        # cu_seqlens_kv_padded. The second index accounts for the shift caused by
        # the first pop.
        args.pop(5)  # cu_seqlens_kv
        args.pop(8)  # cu_seqlens_kv_padded
        args += [window_size, cp_group, cp_stream]
        out = AttnFuncWithCPAndKVAllGather.apply(*args)
    elif cp_comm_type == "a2a":
        args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers]
        out = AttnFuncWithCPAndQKVOA2A.apply(*args)
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")

    return out


cyanguwa's avatar
cyanguwa committed
3513
class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along ``split_dim`` with autograd support.

    Forward behaves like ``torch.split`` and can optionally squeeze ``split_dim`` out
    of every piece. FP8 inputs (``Float8Tensor`` / ``Float8TensorBase``) are split
    on their raw ``_data`` and rebuilt as FP8 pieces.

    Backward concatenates the incoming gradients along ``split_dim``. If forward
    squeezed that dimension, it is first re-inserted into the gradients. When the
    gradients already tile one region of a single storage with identical strides,
    the result is a zero-copy view of that storage instead of a ``torch.cat`` copy.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
        squeeze=False,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        ctx.squeeze = squeeze
        # Rank of the unsplit input. Backward needs it to normalise `split_dim` and to
        # restore a dimension that `squeeze` removed from the pieces.
        if isinstance(mixed_x_layer, Float8TensorBase):
            ctx.input_ndim = mixed_x_layer._data.dim()
        else:
            ctx.input_ndim = mixed_x_layer.dim()

        if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
            mixed_x_layer, Float8Tensor
        ):
            return tuple(
                Float8TensorBase(
                    fp8_scale_inv=mixed_x_layer._scale_inv,
                    fp8_dtype=mixed_x_layer._fp8_dtype,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                    quantizer=mixed_x_layer._quantizer,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        if isinstance(mixed_x_layer, Float8Tensor):
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
        if squeeze:
            out_list = [x.squeeze(split_dim) for x in out_list]
        return out_list

    @staticmethod
    def _restore_squeezed_dim(grad, split_dim: int, ndim: int):
        """Re-insert the size-1 ``split_dim`` that forward squeezed out of ``grad``."""
        if len(grad.shape) != ndim - 1:
            # `squeeze` was a no-op for this piece (its split size was not 1).
            return grad
        if isinstance(grad, Float8Tensor):
            data = grad._data.unsqueeze(split_dim)
            return Float8Tensor.make_like(grad, data=data, shape=data.shape)
        return grad.unsqueeze(split_dim)

    @staticmethod
    def _join_along_dim(tensors, split_sizes, split_dim: int) -> torch.Tensor:
        """Concatenate plain ``tensors`` along ``split_dim``; zero-copy when possible.

        The view is used only when every tensor lives in the same storage with the
        same strides, has the expected piece shape, and starts exactly where the
        previous piece ends along ``split_dim``, measured with the real stride of
        that dimension. Otherwise ``torch.cat`` is used.
        """
        first = tensors[0]
        strides = first.stride()
        storage_ptr = first.untyped_storage().data_ptr()
        base_offset = first.storage_offset()
        step = strides[split_dim]

        expected_shape = list(first.shape)
        expected_offset = base_offset
        tiles_one_region = True
        for size, tensor in zip(split_sizes, tensors):
            expected_shape[split_dim] = size
            if (
                tensor.stride() != strides
                or list(tensor.shape) != expected_shape
                or tensor.untyped_storage().data_ptr() != storage_ptr
                or tensor.storage_offset() != expected_offset
            ):
                tiles_one_region = False
                break
            expected_offset += size * step

        if not tiles_one_region:
            return torch.cat(tensors, dim=split_dim)

        full_shape = list(first.shape)
        full_shape[split_dim] = sum(split_sizes)
        joined = torch.Tensor().to(device=first.device, dtype=first.dtype)
        joined.set_(first.untyped_storage(), base_offset, full_shape, strides)
        return joined

    @staticmethod
    def backward(ctx, *grad_outputs):
        # pylint: disable=missing-function-docstring
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)

        # Normalise against the unsplit input's rank. Squeezed gradients have one
        # dimension fewer, so their own rank cannot be used here.
        dims = ctx.input_ndim
        split_dim = (ctx.split_dim + dims) % dims

        if ctx.squeeze:
            grad_outputs = [
                _SplitAlongDim._restore_squeezed_dim(g, split_dim, dims) for g in grad_outputs
            ]

        # Always return one entry per possible forward input (tensor, split_dim,
        # split_size_or_sections, squeeze). Autograd drops trailing Nones when
        # `apply` was called without `squeeze`.
        if isinstance(grad_outputs[0], Float8Tensor):
            data = _SplitAlongDim._join_along_dim(
                [g._data for g in grad_outputs], split_sizes, split_dim
            )
            return (
                Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape),
                None,
                None,
                None,
            )
        joined = _SplitAlongDim._join_along_dim(list(grad_outputs), split_sizes, split_dim)
        return joined, None, None, None
3649
3650
3651
3652
3653
3654
3655
3656
3657


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    ``sbhd`` inputs are used directly and ``bshd`` inputs are transposed to ``sbhd``
    first. Grouped K/V heads are repeated to match the number of query heads.
    The output comes back in the input's format with heads merged into the last
    dimension: ``[sq, b, np * hn]`` for ``sbhd`` and ``[b, sq, np * hn]`` for ``bshd``.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_type: str = "self",
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Computes ``softmax(scale * Q @ K^T [+ bias]) @ V`` with the mask built by
        ``dpa_utils.get_full_mask``. ``core_attention_bias_type`` selects among no
        bias, a bias added before scaling (``pre_scale_bias``), and a bias added
        after scaling (``post_scale_bias``, or ALiBi generated via
        ``dpa_utils.get_alibi``).

        Raises:
            AssertionError: on an unknown ``qkv_layout``, a missing bias tensor for
                a bias type that needs one, or query heads not divisible by K/V heads.
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )

        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
            dpa_utils.get_full_mask(
                max_seqlen_q,
                max_seqlen_kv,
                attn_mask_type=attn_mask_type,
                attention_mask=attention_mask,
                window_size=window_size,
                attention_type=self.attention_type,
            )
        )

        # Layer scaling is only applied to FP16 inputs.
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            # Fewer K/V heads than query heads: repeat each K/V head so that every
            # query head has a matching one.
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Preallocated result tensor: [b * np, sq, sk]. It is allocated on the inputs'
        # device (not the current CUDA device) so that baddbmm's operands always match.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            ).view(*output_size)

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            matmul_result = matmul_result.view(*output_size) + core_attention_bias
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = dpa_utils.get_alibi(
                    _alibi_cache,
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                    actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
                dtype=query_layer.dtype
            )

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
        )

        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
        # the columns (pad tokens from k) are already zeroed out during softmax
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(max_seqlen_q, batch_size, -1)

        if qkv_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, max_seqlen_q, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Turn an interleaved QKV buffer into separate contiguous Q, K, V tensors.

    The inputs are views into one packed QKV allocation laid out as
    (s, b, ...). Forward hands that allocation to ``tex.fa_prepare_fwd`` and
    returns separate contiguous q, k, v tensors in (b, s, ...) layout.
    Backward re-packs the three gradients with ``tex.fa_prepare_bwd`` and
    splits the result back into one gradient per input.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # All three inputs are non-contiguous views of the same packed QKV
        # memory; only the query view is handed to the kernel, which uses it to
        # reach the full QKV region. key_layer/value_layer are not read here.
        packed = tex.fa_prepare_fwd(query_layer)
        pieces = split_tensor_along_dim(packed, 0, 3)
        # Drop the leading size-1 axis left on each piece by the split.
        q_out, k_out, v_out = (torch.squeeze(piece, 0) for piece in pieces)
        return q_out, k_out, v_out

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        # Re-pack the three gradients into a single buffer, then split it along
        # the last axis to obtain one gradient per forward input.
        grad_q, grad_k, grad_v = split_tensor_along_dim(tex.fa_prepare_bwd(dq, dk, dv), -1, 3)
        return grad_q, grad_k, grad_v


class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Configure the flash-attn wrapper.

        Args:
            softmax_scale: scale applied to QK^T, forwarded to flash-attn.
            attention_dropout: dropout probability, applied only while
                ``self.training`` is true.
            attention_dropout_ctx: zero-argument callable returning a context
                manager entered around the attention call.
            attention_type: ``"self"`` or cross attention. Only the padding
                path distinguishes ``"self"`` from anything else.
            layer_number: stored as-is; defaults to 1.
            deterministic: forwarded to flash-attn (v2.4.1+, v3) and to the
                context-parallel path.

        Raises:
            AssertionError: if an installed flash-attn is outside the
                supported version range.
        """
        super().__init__()

        if fa_utils.is_installed:
            assert (
                fa_utils.version >= fa_utils.version_required
            ), f"FlashAttention minimum version {fa_utils.version_required} is required."
            assert (
                fa_utils.version <= fa_utils.max_version
            ), f"FlashAttention maximum version {fa_utils.max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic
        self.logger = logging.getLogger("FlashAttention")
        self.logger.setLevel(attn_log._log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(attn_log._stream_handler)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
        quantizers=None,
    ) -> torch.Tensor:
        """flash-attn fprop

        Args:
            query_layer, key_layer, value_layer: FP16/BF16 CUDA tensors or
                Float8Tensors, laid out according to ``qkv_layout``.
            attention_mask: read only for padding masks when the matching
                ``cu_seqlens_*`` are not given. It is a single mask for
                self-attention and a ``(q_mask, kv_mask)`` pair otherwise.
            qkv_layout: one of ``QKVLayouts``. Its first ``_``-separated
                component (letters only) selects the format: sbhd, bshd or thd.
            cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths. They are
                required for thd and derived when missing for sbhd/bshd.
            max_seqlen_q, max_seqlen_kv: maximum sequence lengths. For
                sbhd/bshd they are recomputed from the tensor shapes (times the
                CP size). For thd they are derived from ``cu_seqlens_*`` when
                None.
            attn_mask_type: mask type string. A substring match on
                ``"padding"`` selects the varlen path; ``"causal"`` enables
                causal masking.
            window_size: sliding-window ``(left, right)``. ``None`` is sent to
                flash-attn as its default ``(-1, -1)``.
            alibi_slopes: ALiBi slopes, forwarded for flash-attn v2.4+. Not
                supported with context parallelism.
            cp_group, cp_global_ranks, cp_stream, cp_comm_type: context
                parallelism configuration. CP is active when the product of
                group sizes exceeds 1.
            fp8, fp8_meta, quantizers: FP8 configuration. Only the flash-attn
                v3 branch consumes it here.

        Returns:
            A contiguous output tensor. Its shape is [sq, b, h*d] for sbhd,
            [b, sq, h*d] for bshd and [t, h*d] for thd (sq is the local
            sequence length under CP).
        """
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        # Total context-parallel size: a single group, or the product over a
        # list of groups.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # e.g. "sbh3d" -> "sbhd", "bshd_bs2hd" -> "bshd", "thd_thd_thd" -> "thd"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        # sbhd inputs are transposed to batch-first before calling flash-attn;
        # for Float8Tensors the transpose is applied to the raw `_data` buffers.
        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
            if qkv_format == "sbhd":
                # For now just 128, will make it more general in the future
                if (
                    query_layer.shape[-1] == 128
                    and query_layer.shape[0] * query_layer.shape[1] >= 512
                    and qkv_layout == "sbh3d"
                ):
                    query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                        query_layer, key_layer, value_layer
                    )
                else:
                    query_layer, key_layer, value_layer = [
                        x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
                    ]
            if context_parallel:
                query_layer, key_layer, value_layer = [
                    x.contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        else:
            if qkv_format == "sbhd":
                query_layer._data, key_layer._data, value_layer._data = [
                    x.transpose(0, 1)
                    for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
                query_layer, key_layer, value_layer = [
                    Float8Tensor.make_like(x, data=x._data, shape=x._data.shape)
                    for x in (query_layer, key_layer, value_layer)
                ]
            if context_parallel:
                query_layer._data, key_layer._data, value_layer._data = [
                    x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
                ]

        # For sbhd/bshd this is the batch dimension (sbhd was transposed above).
        batch_size = query_layer.shape[0]

        if qkv_format in ["sbhd", "bshd"]:
            # Global sequence lengths: the local shard length times the CP size.
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size

            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

                # Drop padded tokens so the varlen kernel sees only real ones;
                # `indices_q` is kept to scatter the output back afterwards.
                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
                            attention_mask
                        )
                    else:
                        indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = dpa_utils.PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
                            attention_mask[0]
                        )
                        cu_seqlens_kv, indices_kv = dpa_utils.get_cu_seqlens_and_indices(
                            attention_mask[1]
                        )
                    else:
                        indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = dpa_utils.get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = dpa_utils.PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = dpa_utils.PackTensors.apply(
                        indices_kv, key_layer, value_layer
                    )
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q if qkv_format == "thd" else None,
                    cu_seqlens_kv if qkv_format == "thd" else None,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                    window_size=window_size,
                    quantizers=quantizers,
                    pad_between_seqs=False,
                )
        else:

            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            # flash-attn indexes window_size[0]/[1] and has no notion of None;
            # use its documented default (-1, -1) (unbounded window) instead.
            fa_window_size = (-1, -1) if window_size is None else window_size

            with self.attention_dropout_ctx():
                fa_optional_forward_kwargs = {}
                if fa_utils.v2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = fa_window_size
                if fa_utils.v2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if fa_utils.v2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                # Fixed-length batches use the dense API; padded sbhd/bshd and
                # thd use the varlen API with explicit cu_seqlens/max_seqlen.
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                    func = flash_attn_func if not fa_utils.use_v3 else flash_attn_func_v3
                else:
                    if fa_utils.v2_5_7_plus:
                        fa_optional_forward_kwargs["block_table"] = None
                    func = (
                        flash_attn_varlen_func if not fa_utils.use_v3 else flash_attn_varlen_func_v3
                    )
                    fa_optional_forward_args_thd.append(cu_seqlens_q)
                    fa_optional_forward_args_thd.append(cu_seqlens_kv)
                    fa_optional_forward_args_thd.append(max_seqlen_q)
                    fa_optional_forward_args_thd.append(max_seqlen_kv)
                if fa_utils.use_v3:
                    fa_3_optional_forward_kwargs = {}
                    fa_3_optional_forward_kwargs["window_size"] = fa_window_size
                    fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                    if fp8:
                        QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
                        torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
                        torch_orig_dtype = query_layer.dtype

                        def convert_to_torch_float8(tensor, dtype):
                            # Zero-copy: a plain torch tensor of `dtype` aliasing
                            # the Float8Tensor's raw `_data` storage.
                            out = torch.Tensor().to(device=tensor.device, dtype=dtype)
                            out.set_(
                                tensor._data.untyped_storage(),
                                tensor._data.storage_offset(),
                                tensor._data.shape,
                                tensor._data.stride(),
                            )
                            return out

                        # "fp8_mha" decides outputs in fp8, while inputs are inferred from
                        # the real dtype
                        assert isinstance(key_layer, query_layer.__class__) and isinstance(
                            value_layer, query_layer.__class__
                        ), "q, k, and v must have the same type."
                        if not isinstance(query_layer, Float8Tensor):
                            query_layer, key_layer, value_layer = (
                                QKV_quantizer(x) for x in [query_layer, key_layer, value_layer]
                            )
                        fa_3_optional_forward_kwargs["descale_q"] = (
                            query_layer._scale_inv.unsqueeze(0)
                        )
                        fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze(
                            0
                        )
                        fa_3_optional_forward_kwargs["descale_v"] = (
                            value_layer._scale_inv.unsqueeze(0)
                        )
                        query_layer, key_layer, value_layer = (
                            convert_to_torch_float8(x, torch_dtype)
                            for x in [query_layer, key_layer, value_layer]
                        )
                    try:
                        output, _ = func(
                            query_layer,
                            key_layer,
                            value_layer,
                            *fa_optional_forward_args_thd,
                            softmax_scale=self.softmax_scale,
                            causal="causal" in attn_mask_type,
                            **fa_3_optional_forward_kwargs,
                        )
                    except TypeError as e:
                        # The v3 beta API changes often; point the user at an update.
                        if fa_utils.v3_0_0_beta:
                            e.args = (
                                e.args[0]
                                + ". Please update your flash-attn v3 (beta) installation as it "
                                + "may have added more supported arguments to its API. \n"
                                + fa_utils.v3_installation_steps,
                            ) + e.args[1:]
                        raise

                    if fp8:
                        output = output.to(dtype=torch_orig_dtype)
                    if fp8 and fp8_meta["recipe"].fp8_mha:
                        O_quantizer = quantizers["scaling_fwd"][META_O]
                        output = O_quantizer(output)
                else:
                    output = func(
                        query_layer,
                        key_layer,
                        value_layer,
                        *fa_optional_forward_args_thd,
                        self.attention_dropout if self.training else 0.0,
                        softmax_scale=self.softmax_scale,
                        causal="causal" in attn_mask_type,
                        **fa_optional_forward_kwargs,
                    )

        # Scatter the packed (padding-free) output back to b * s rows.
        if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
            output = dpa_utils.UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            if fp8 and fp8_meta["recipe"].fp8_mha:
                output_data = (
                    output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
                    .transpose(0, 1)
                    .contiguous()
                )
                output = Float8Tensor.make_like(
                    output,
                    data=output_data,
                    shape=output_data.shape,
                )
            else:
                output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
        elif qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
        elif qkv_format == "thd":
            # thd -> t(hd)
            output = output.reshape(output.shape[0], -1)

        return output.contiguous()


def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension without copying.

    Builds a view over the storage of ``tensors[0]`` with a new axis of size
    ``len(tensors)`` inserted at ``dim``. No data is copied. The inputs are
    therefore assumed to be equally shaped, equally strided, adjacent slices
    of a single allocation (for example q, k, v carved out of one packed QKV
    buffer). Only ``tensors[0]``'s storage, offset, shape and strides are
    read.

    Args:
        tensors: tensors to combine (plain tensors or ``Float8Tensor``s).
        dim: position of the new axis. Expected to be >= 1. The new axis's
            stride is derived from the stride of axis ``dim - 1``, so
            ``dim == 0`` would wrap around to the last axis.

    Returns:
        A tensor of shape ``tensors[0].shape`` with ``len(tensors)`` inserted
        at ``dim``. For ``Float8Tensor`` inputs the view is built over the raw
        ``_data`` buffer and wrapped back with ``Float8Tensor.make_like``.
    """
    num_tensors = len(tensors)
    first = tensors[0]
    new_shape = list(first.shape)
    new_shape.insert(dim, num_tensors)

    is_fp8 = isinstance(first, Float8Tensor)
    # For Float8Tensor the storage to alias lives in the raw `_data` buffer.
    base = first._data if is_fp8 else first
    new_stride = list(base.stride())
    # Strides are ints: use exact integer division instead of a float round trip.
    new_stride.insert(dim, new_stride[dim - 1] // num_tensors)

    combined_tensor = torch.Tensor().to(device=first.device, dtype=base.dtype)
    combined_tensor.set_(
        base.untyped_storage(),
        base.storage_offset(),
        new_shape,
        new_stride,
    )
    if is_fp8:
        combined_tensor = Float8Tensor.make_like(first, data=combined_tensor, shape=new_shape)
    return combined_tensor


class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
4290
4291
4292
4293
4294
4295
4296
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
4297
4298
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
4299
4300
4301
4302
4303
4304
4305
4306
4307
4308
        q,
        k,
        v,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
4309
        window_size,
4310
4311
4312
4313
4314
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
4315
        quantizers,
4316
        deterministic,
4317
    ):
4318
        # pylint: disable=missing-function-docstring
4319
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
4320
        is_input_fp8 = False
4321
        is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
4322
4323
4324
4325

        # FP16/BF16 attn:                  fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = True:  fake_dtype = torch.float8_e4m3fn
4326
4327
4328
        fake_dtype = q.dtype

        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
4329
            dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
4330
        )
4331
4332
        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
4333
4334
4335
            assert isinstance(k, q.__class__) and isinstance(
                v, q.__class__
            ), "q, k, and v must have the same type."
4336

4337
            is_input_fp8 = isinstance(q, Float8Tensor)
4338
            q_fp8, k_fp8, v_fp8 = None, None, None
4339
            if is_input_fp8:
4340
                q_fp8, k_fp8, v_fp8 = q, k, v
4341
4342
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
4343
                qkv_group = len(qkv_layout.split("_"))
4344
4345
4346
4347
4348
4349
4350
4351
4352
4353
4354
4355
4356
4357
4358
4359
4360
4361
4362
4363
                match qkv_group:
                    case 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                        qkv_fp8 = QKV_quantizer(qkv)
                        q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True)
                    case 2:
                        q_fp8 = QKV_quantizer(q)
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                        kv_fp8 = QKV_quantizer(kv_c)
                        k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1], True)
                    case 3:
                        q_fp8 = QKV_quantizer(q)
                        k_fp8 = QKV_quantizer(k)
                        v_fp8 = QKV_quantizer(v)
                    case _:
                        raise "Invalid qkv_layout " + qkv_layout
4364
            # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
4365
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
4366
4367
4368
4369
4370
4371
4372
4373
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
4374
                fake_dtype,
4375
4376
                fused_attention_backend,
                attn_bias,
4377
4378
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
4379
4380
                S_quantizer,
                O_quantizer,
4381
4382
4383
4384
4385
4386
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
4387
                window_size,
4388
4389
                rng_gen,
            )
4390
            if is_output_fp8:
4391
                out_ret = out_fp8
4392
            else:
4393
                out_ret = out_fp8.dequantize().view(out_fp8.shape)
4394
4395
            # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16
            # is_output_fp8 = True:  out_save.dtype = torch.float8_e4m3fn
4396
4397
            out_save = out_ret

4398
            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
4399
                # 1: qkv packed, 2: kv packed, 3: qkv separate
4400
4401
4402
4403
4404
4405
                if is_input_fp8:
                    qkv_group = len(qkv_layout.split("_"))
                    if qkv_group == 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
4406
4407
                        qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape)
                        q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
4408
                    if qkv_group == 2:
4409
                        q = q.dequantize()
4410
4411
4412
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
4413
4414
                        kv_no_fp8 = kv.dequantize()
                        k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True)
4415
                    if qkv_group == 3:
4416
4417
4418
                        q = q.dequantize()
                        k = k.dequantize()
                        v = v.dequantize()
4419
                if is_output_fp8:
4420
4421
4422
                    out_save = out_fp8.dequantize()

            fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
4423
        else:
4424
            # q, k, v, out_ret: torch.float16 or torch.bfloat16
4425
            out_ret, aux_ctx_tensors = fused_attn_fwd(
4426
4427
4428
4429
4430
4431
4432
4433
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
4434
                fake_dtype,
4435
4436
                fused_attention_backend,
                attn_bias,
4437
4438
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
4439
4440
                None,  # s_quantizer
                None,  # o_quantizer
4441
4442
4443
4444
4445
4446
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
4447
                window_size,
4448
4449
                rng_gen,
            )
4450
            out_save = out_ret
4451
            fp8_tensors = (None, None, None, None)
4452

4453
4454
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

4455
        from .cpu_offload import CPUOffloadEnabled
4456

4457
        if CPUOffloadEnabled:
4458
4459
4460
4461
4462
4463
4464
            if ctx.fp8:
                tensor_list = fp8_tensors
            else:
                tensor_list = [q, k, v, out_save]

            tensor_list.extend(aux_ctx_tensors)

4465
            qkv_layout = "sbhd_sbhd_sbhd"
4466
4467
4468
4469
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

4470
4471
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
4472
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
4473
4474
        tensors_to_save, tensor_objects = prepare_for_saving(
            *fp8_tensors,
4475
4476
4477
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
4478
4479
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
4480
4481
            *aux_ctx_tensors,
        )
4482
4483
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
4484
        ctx.fp8_meta = fp8_meta
4485
4486
4487
4488
4489
4490

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = S_quantizer

4491
4492
4493
4494
4495
4496
4497
4498
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
4499
        ctx.window_size = window_size
4500
        ctx.fused_attention_backend = (
4501
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
4502
        )
4503
        ctx.use_FAv2_bwd = use_FAv2_bwd
4504
        ctx.deterministic = deterministic
4505

4506
        return out_ret
4507
4508
4509

    @staticmethod
    def backward(ctx, d_out):
        """Compute the gradients of the fused attention op.

        Returns a positional tuple matching the inputs of ``forward``:
        dq, dk, dv sit at indices 7-9. For bias types other than "no_bias" and
        "alibi", the bias gradient returned by ``fused_attn_bwd`` sits at index 10.
        Every other slot is None.
        """
        if ctx.is_output_fp8:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."

        # FP16/BF16 attn:                  fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = True:  fake_dtype = torch.float8_e5m2
        fake_dtype = d_out.dtype

        d_out = d_out.contiguous()
        (
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *other_tensors,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

        # Remaining saved tensors are the auxiliary context produced by the forward kernel.
        aux_ctx_tensors = other_tensors

        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        # Default for the bias-gradient slot; only the fused_attn_bwd paths overwrite it.
        rest = [None]
        if ctx.use_FAv2_bwd:
            # flash-attn v2 backward: aux tensors hold the softmax LSE and the RNG state.
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Trim the last dimension to match d_out's (presumably undoing any
            # head-dim padding; TODO confirm).
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    if ctx.is_output_fp8:
                        d_out_fp8 = d_out
                    else:
                        d_out_fp8 = ctx.dO_quantizer(d_out)
                    dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
                    # q_fp8, k_fp8, v_fp8, out_fp8:      torch.float8_e4m3fn
                    # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
                        fake_dtype,
                        dqkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        ctx.S_quantizer,
                        ctx.dP_quantizer,
                        ctx.dQKV_quantizer,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

                    # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
                    # is_input_fp8 = True:  dq, dk, dv: torch.float8_e5m2
                    if not ctx.is_input_fp8:
                        # Number of "_"-separated groups in the layout: 1 = packed QKV,
                        # 2 = Q plus packed KV, 3 = separate Q, K, V.
                        qkv_group = len(ctx.qkv_layout.split("_"))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find("3")
                            dqkv_fp8_data = _combine_tensors(
                                [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim
                            )
                            dqkv_fp8 = dq_fp8.make_like(
                                tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape
                            )
                            dqkv = dqkv_fp8.dequantize()
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True)
                        if qkv_group == 2:
                            dq = dq_fp8.dequantize()
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
                            # The flattened 2D view is dequantized in one call; restore the
                            # packed shape before splitting, because `dim` indexes the packed
                            # layout (e.g. 2 for "bs2hd"), not the flattened 2D view.
                            dkv = dkv_c_fp8.dequantize().view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1], True)
                        if qkv_group == 3:
                            dq = dq_fp8.dequantize()
                            dk = dk_fp8.dequantize()
                            dv = dv_fp8.dequantize()
                    else:
                        dq, dk, dv = dq_fp8, dk_fp8, dv_fp8
                else:
                    if isinstance(d_out, QuantizedTensor):
                        d_out = d_out.dequantize()
                    dqkv_dtype = TE_DType[d_out.dtype]
                    # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
                        fake_dtype,
                        dqkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        # if no_bias or alibi, return dqkv
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                dq,
                dk,
                dv,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )
        # else, return (dqkv, dbias)
        return (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            dq,
            dk,
            dv,
            rest[0],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
4736

4737

4738
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |

    Parameters
    ----------
    softmax_scale : float
        scale applied to the attention scores, forwarded to the kernels.
    attention_dropout : float, default = 0.0
        dropout probability; only applied while the module is in training mode.
    attention_dropout_ctx : Callable, default = `nullcontext`
        context-manager factory wrapped around the attention kernel call.
    attention_type : str, default = "self"
        "self" derives both q and kv cu_seqlens from a single attention mask;
        any other value expects a (q_mask, kv_mask) tuple.
    layer_number : Optional[int], default = None
        layer index; stored as 1 when None.
    deterministic : bool, default = False
        forwarded to the fused attention kernels.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # flash-attn v2 backward is opt-in via the environment and only on compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading Transformer Engine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in Transformer Engine 2.0.
            """
            # Iterate over snapshots: calling .remove() on a list while iterating it
            # skips the element after each removal. The lists themselves are still
            # mutated in place, which is how load_state_dict post-hooks report changes.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
        quantizers=None,
        pad_between_seqs: bool = False,
    ) -> torch.Tensor:
        """fused attention fprop

        Validates the inputs, derives cu_seqlens / max_seqlen where they are not
        given, and dispatches to the context-parallel path (when the CP group
        size is > 1) or to FusedAttnFunc. Returns the output with its last two
        dimensions (heads, head_dim) merged into one.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # Total CP size: product over all groups when hierarchical CP passes a list.
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # e.g. "sbhd_sb2hd" -> "sbhd": letters of the first layout group only.
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            if qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            # Local shards hold 1/cp_size of the sequence; scale up to the global length.
            max_seqlen_q *= cp_size
            max_seqlen_kv *= cp_size
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
            else:
                if cu_seqlens_q is None:
                    cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"

        # Without explicit padded offsets, assume no padding between sequences.
        if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None):
            cu_seqlens_q_padded = cu_seqlens_q
            cu_seqlens_kv_padded = cu_seqlens_kv

        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if fp8:
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across TP+CP group is necessary when using context parallelism with"
                " FP8!"
            )

        if context_parallel:
            assert (
                fp8
                or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    deterministic=self.deterministic,
                    use_fused_attention=True,
                    window_size=window_size,
                    fp8=fp8,
                    fp8_meta=fp8_meta,
                    quantizers=quantizers,
                    pad_between_seqs=pad_between_seqs,
                )
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    query_layer,
                    key_layer,
                    value_layer,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    quantizers,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
5006
5007


5008
class DotProductAttention(TransformerEngineBaseModule):
5009
5010
5011
5012
5013
5014
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

5015
        Argument :attr:`attention_mask` in the `forward` call is only used when
5016
        :attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`.
5017
5018
5019

    .. warning::

5020
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
5021
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
5022
5023
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
5024

5025
5026
5027
5028
5029
5030
5031
    .. note::

        Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
        As the FP8 attention support expands from one backend to multiple backends, the location
        of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).


5032
5033
5034
5035
    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
5036
5037
5038
    kv_channels : Union[int, Tuple[int, int]]
                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
5039
5040
5041
5042
5043
5044
5045
5046
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
5047
5048
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
5049
    attn_mask_type: str, default = `causal`
5050
                   type of attention mask passed into softmax operation, options are "`no_mask`",
5051
5052
5053
5054
5055
5056
5057
5058
5059
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
5060
                   "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
5061
5062
5063
5064
5065
5066
5067
5068
5069
5070
5071
5072
5073
5074
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
5075
5076
5077
5078
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
5079
5080
5081
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
5082
                be overridden by :attr:`window_size` in `forward` as well.
5083
5084
    attention_type: str, default = `self`
                   type of attention, either "`self`" and "`cross`".
5085
5086
5087
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
5088
5089
5090
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
5091
               `h` the number of heads, `d` head size, and `t` the total number of tokens
5092
5093
5094
5095
5096
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
5097
               For that, please use `get_qkv_layout` to gain the layout information.
5098
5099
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
5100
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
5101
5102
5103
5104
5105
5106
5107
5108
5109

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
5110
    cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None`
5111
              context parallel process group.
5112
5113
5114
              ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
              List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
              and cp_group[1] are for a2a and p2p communications respectively.
5115
5116
5117
5118
5119
5120
5121
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
5122
    cp_comm_type : str, default = `p2p`
5123
                  inter-gpu communication type for context parallelism.
5124
                  Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
5125
5126
5127
5128
5129
5130
                  "p2p": Exchange KV chunks with P2P communications in ring topology.
                         P2P is async and can be overlapped with attention compute.
                  "all_gather": All-gather to get full sequence of KV before attention.
                                The all-gather is not async, and cannot be overlapped.
                  "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                         group, and gather to get full sequence of QKV.
5131
5132
5133
                  "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                  across each CP sub-group (e.g., via NVLink), then exchanging KV with
                  p2p between sub-groups (e.g., via IBLink).
5134
5135
5136
5137
5138
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Configure the module and instantiate its three attention backends.

        Constructor arguments are documented on the class. Besides storing the
        configuration, this normalizes the mask-type spelling, resolves tensor-parallel
        and GQA sizing, selects the context manager used around attention dropout,
        may set cuDNN workspace-related environment variables, and builds the
        FlashAttention, FusedAttention and UnfusedDotProductAttention backends.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.logger.setLevel(attn_log._log_level)
        # getLogger returns a shared logger; only attach the stream handler when no
        # handler is reachable yet, so repeated construction does not duplicate output.
        if not self.logger.hasHandlers():
            self.logger.addHandler(attn_log._stream_handler)

        self.qkv_format = qkv_format
        # Accept comma-separated spellings ("padding,causal", "causal,padding") and
        # normalize both orderings to the canonical "padding_causal".
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)

        # Without an explicit TP group the caller-provided tp_size is used, and the
        # (None) group is only registered when tp_size == 1. With a group, its world
        # size takes precedence over tp_size.
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)

        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number

        # Context-parallel settings; they can be changed later via
        # set_context_parallel_group().
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type

        # kv_channels is either a single head dim shared by K and V, or a
        # (head_dim_k, head_dim_v) pair.
        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
        )
        self.hidden_size_per_attention_head_v = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
        )

        # Default to standard multi-head attention (one KV group per query head).
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        # Choose the context manager wrapped around attention dropout: a no-op when
        # sequence parallelism is on or no tracker factory was given, otherwise the
        # tracker's fork().
        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            # NOTE(review): presumably registers the tracker's RNG states globally
            # (e.g. for CUDA-graph support) -- confirm against set_all_rng_states.
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        # Default scale is 1/sqrt(head dim of K).
        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head_k)

        # Determinism is requested either via NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 or
        # by enabling PyTorch's deterministic algorithms.
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )

        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        # Arguments shared by all three backends.
        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale,
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older Transformer Engine checkpoints. Will phase out
            this hook in Transformer Engine 2.0.
            """
            # Filter in place (the hook must mutate the list it was given). Calling
            # list.remove() while iterating the same list skips the element that
            # follows each removed key, leaving adjacent matches behind.
            incompatible_keys.missing_keys[:] = [
                key
                for key in incompatible_keys.missing_keys
                if "core_attention._extra_state" not in key
            ]

        self.register_load_state_dict_post_hook(remove_extra_states_check)

5282
5283
5284
5285
5286
5287
5288
5289
5290
5291
5292
5293
5294
5295
5296
5297
5298
5299
5300
5301
5302
5303
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load state, redirecting legacy FP8 attention metadata when necessary.

        Transformer Engine 1.6 and 1.7 wrote the FP8 attention metadata under a
        `core_attention.fused_attention._extra_state` key rather than under
        `core_attention._extra_state`. If the checkpoint only has the legacy key,
        loading continues with `fused_attention.` appended to the prefix. Please see
        `FP8 checkpoint compatibility
        <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
        """
        checkpoint_keys = state_dict.keys()
        has_legacy_layout = any(
            "core_attention.fused_attention._extra_state" in name for name in checkpoint_keys
        )
        has_current_layout = any("core_attention._extra_state" in name for name in checkpoint_keys)

        # Only redirect when the checkpoint exclusively uses the old layout.
        if has_legacy_layout and not has_current_layout:
            prefix += "fused_attention."

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

5304
5305
5306
5307
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Run ``attention_func`` under activation checkpointing.

        The positional and keyword arguments are forwarded unchanged to
        ``attention_func`` through ``checkpoint``. Saved activations are not
        distributed, and this module's RNG-state tracker and tensor-parallel group
        are handed to the checkpoint helper.
        """

        def _run_attention(*call_args, **call_kwargs):
            return attention_func(*call_args, **call_kwargs)

        # *forward_args binds positionally right after the function regardless of
        # where it appears relative to the keyword arguments below.
        return checkpoint(
            _run_attention,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

5326
5327
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Configure context parallelism for this module; call before the forward pass.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  process group(s) used for context parallelism. Pass a single
                  ProcessGroup when cp_comm_type is "p2p", "all_gather" or "a2a".
                  Pass a list for "a2a+p2p": cp_group[0] carries the a2a
                  communication and cp_group[1] the p2p communication.
        cp_global_ranks : List[int]
                         global ranks that make up the context-parallel group.
        cp_stream : torch.cuda.Stream
                   CUDA stream used for context-parallel execution.
        cp_comm_type : str, default = `p2p`
                      how GPUs exchange data for context parallelism; one of
                      "p2p", "all_gather", "a2a" or "a2a+p2p".
                      "p2p": KV chunks travel around a ring via P2P transfers, which
                             are async and can overlap with attention compute.
                      "all_gather": KV is all-gathered to the full sequence before
                                    attention; this is blocking and cannot overlap.
                      "a2a": DeepSpeed-Ulysses style; attention heads are scattered
                             across the CP group and QKV is gathered to full length.
                      "a2a+p2p": hierarchical scheme; a2a on QKV inside each CP
                      sub-group (e.g. over NVLink), then p2p KV exchange between
                      sub-groups (e.g. over InfiniBand).
        """
        (
            self.cp_group,
            self.cp_global_ranks,
            self.cp_stream,
            self.cp_comm_type,
        ) = (cp_group, cp_global_ranks, cp_stream, cp_comm_type)
5365

5366
    @no_torch_dynamo(recursive=False)
5367
5368
5369
5370
5371
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
5372
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
5373
5374
5375
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
5376
5377
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
5378
5379
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
5380
        attn_mask_type: Optional[str] = None,
5381
        window_size: Optional[Tuple[int, int]] = None,
5382
        checkpoint_core_attention: bool = False,
5383
5384
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
5385
        alibi_slopes: Optional[torch.Tensor] = None,
5386
        fast_zero_fill: bool = True,
5387
        inference_params: Optional[InferenceParams] = None,
5388
        pad_between_seqs: Optional[bool] = None,
5389
5390
5391
5392
5393
5394
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

5395
5396
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes '"padding"' or `"arbitrary"`.
5397

5398
5399
        .. note::

5400
5401
5402
5403
5404
5405
5406
5407
5408
5409
5410
5411
5412
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://arxiv.org/pdf/2305.13245.pdf>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
5413
            and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
5414
5415
5416
5417
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
5418
5419
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
5420
            optimizations in FusedAttention. When unset, Transformer Engine determines the code path
5421
5422
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
5423

5424
5425
5426
5427
5428
5429
5430
5431
5432
5433
5434
5435
5436
5437
5438
5439
5440
5441
5442
5443
5444
5445
5446
5447
5448
5449
5450
5451
5452
5453
5454
5455
5456
5457
5458
5459
5460
5461
5462
5463
5464
5465
5466
5467
5468
5469
5470
5471
5472
5473
5474
5475
5476
5477
        .. note::
            .. _cu_seqlens note:

            When training data has variable sequence lengths, users have two options.

            1. Manipulate the data and pad all sequences to the same length. Use
               :attr:`qkv_format` = {"bshd", "sbhd"} and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
               (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
               the real sequence length information. For example, a batch of 3 sequences
               [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative
               sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

            2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
               as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed
               without any padding, and the sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

               In certain use cases, a varying number of identifier tokens are inserted between
               sequences. These tokens do not participate in the attention calculation.
               :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
               in such cases to correctly identify the start and end of each sequence in a batch.
               For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and
               :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13]
               for self-attention.

        .. note::
            .. _max_seqlen note:

            When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch.
            :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of
            :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will
            infer them as such.

            When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and
            :attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch.
            When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`.
            This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this
            overhead, users are recommended to obtain the maximum sequence lengths from the data loaders
            and pass them in.

            - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch,
              dynamic shapes need to be supported for tensor construction. FlashAttention and
              UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static
              to create graphs before performance heuristics analysis. To reduce the number of graphs created
              per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size,
              :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of
              :attr:`query_layer`, "t" dimension of :attr:`key_layer`}.

5478
5479
5480
5481
5482
5483
5484
5485
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
5486
5487
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
5488
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
5489
5490
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
5491
5492
5493
5494
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
5495
5496
5497
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
5498
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
5499
                   with shape [batch_size + 1] and dtype torch.int32.
5500
                   See :ref:`note<cu_seqlens note>` for more details.
5501
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
5502
5503
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
5504
                   See :ref:`note<cu_seqlens note>` for more details.
5505
5506
5507
5508
5509
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
5510
                   See :ref:`note<cu_seqlens note>` for more details.
5511
5512
5513
5514
5515
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
5516
                   See :ref:`note<cu_seqlens note>` for more details.
5517
5518
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
5519
                      See :ref:`note<max_seqlen note>` for more details.
5520
5521
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
5522
                       See :ref:`note<max_seqlen note>` for more details.
5523
5524
5525
5526
5527
5528
5529
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
5530
        window_size: Optional[Tuple[int, int]], default = `None`
5531
                    Sliding window size for local attention.
5532
5533
5534
5535
5536
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
5537
        core_attention_bias_type: str, default = `no_bias`
5538
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
5539
        core_attention_bias: Optional[torch.Tensor], default = `None`
5540
5541
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
5542
5543
5544
5545
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
5546
        fast_zero_fill: bool, default = `True`
5547
                    Whether to use the fast path to set output tensors to 0 or not.
5548
5549
5550
5551
5552
5553
5554
5555
5556
5557
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
5558
5559
5560
        pad_between_seqs: Optional[bool], default = `None`
            If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
            If true, there are padding tokens between individual sequences in a packed batch.
5561
        """
5562

5563
5564
5565
5566
5567
5568
5569
5570
5571
        with self.prepare_forward(
            query_layer,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:
            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
5572
                        self.logger.warning(
5573
5574
5575
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
5576
5577
5578
5579
5580
5581
5582
5583
5584
5585
5586

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
5587

5588
5589
5590
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
5591
5592
5593
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
5594
5595
5596
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
5597
5598
5599
5600
5601
5602
5603
5604
            assert (
                key_layer.shape[-1] == self.hidden_size_per_attention_head_k
            ), f"Keys have head_dim = {key_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
            assert (
                value_layer.shape[-1] == self.hidden_size_per_attention_head_v
            ), f"Values have head_dim = {value_layer.shape[-1]}, "
            "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
5605

5606
5607
5608
            if qkv_format is None:
                qkv_format = self.qkv_format

5609
5610
5611
5612
5613
5614
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
5615
            assert (
5616
5617
5618
5619
5620
5621
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
5622

5623
5624
            if window_size is None:
                window_size = self.window_size
5625
            window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
5626

5627
5628
5629
5630
5631
5632
5633
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
5634

5635
5636
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
5637

5638
5639
5640
5641
5642
                # convert causal to causal_bottom_right in inference when KV-caching is in use
                # so users can run with the same attn_mask_type for training and inference
                if attn_mask_type in ["causal", "padding_causal"]:
                    attn_mask_type = attn_mask_type + "_bottom_right"

5643
5644
5645
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
5646

5647
5648
5649
5650
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
5651

5652
5653
5654
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
5655

5656
5657
5658
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
5659

5660
5661
5662
5663
5664
5665
5666
5667
5668
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
5669

5670
5671
5672
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
5673

5674
5675
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
5676
5677

            assert (
5678
5679
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
5680
5681
5682
5683
            ), (
                "Keys and values must have num_gqa_group ="
                f" {self.num_gqa_groups_per_partition} heads!"
            )
5684
5685
5686
5687
5688
5689
5690
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
5691
                assert all(
5692
5693
5694
5695
5696
5697
5698
5699
5700
5701
5702
5703
5704
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
5705
                batch_size = len(cu_seqlens_q) - 1
5706
                if max_seqlen_q is None:
5707
5708
5709
5710
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
5711
                    max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
5712
                if max_seqlen_kv is None:
5713
5714
5715
5716
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
5717
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
5718

5719
5720
5721
5722
5723
5724
            cp_size = 1
            if isinstance(self.cp_group, dist_group_type):
                cp_size = get_distributed_world_size(self.cp_group)
            elif isinstance(self.cp_group, list):
                for group in self.cp_group:
                    cp_size *= get_distributed_world_size(group)
5725
5726
            context_parallel = cp_size > 1

5727
            if qkv_format in ["sbhd", "bshd"]:
5728
                assert all(
5729
5730
5731
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
5732
5733
                    max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
5734
                    batch_size = query_layer.shape[1]
5735
                else:
5736
5737
                    max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
5738
                    batch_size = query_layer.shape[0]
5739
5740
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
5741
5742
5743
5744
5745
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
5746
                        the sequence dimension in 'query_layer'!"""
5747
5748
5749
5750
5751
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
5752
                        the sequence dimension in 'key_layer' and 'value_layer'!"""
5753
5754
5755
5756
5757
                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
5758
                        if self.attention_type == "self":
5759
                            cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
5760
5761
                            cu_seqlens_kv = cu_seqlens_q
                        else:
5762
5763
                            cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
5764
                    else:
5765
                        cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
5766
5767
5768
5769
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
5770
                        cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
5771
5772
5773
5774
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
5775

5776
5777
5778
5779
5780
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
5781
5782
5783
5784
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = (
                    dpa_utils.get_qkv_layout(
                        query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                    )
5785
5786
                )
            else:
5787
                qkv_layout, query_layer, key_layer, value_layer = dpa_utils.get_qkv_layout(
5788
5789
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
5790

5791
5792
5793
5794
5795
5796
5797
5798
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
5799
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
5800
5801
5802
5803
5804
5805
5806
5807
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
5808
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
5809
5810
5811
5812
5813
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

5814
5815
            core_attention_bias_shape = None
            if core_attention_bias is not None:
5816
                if (
5817
5818
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
5819
                ):
5820
5821
5822
5823
5824
5825
5826
5827
5828
5829
5830
5831
5832
5833
5834
5835
5836
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

5837
5838
5839
5840
5841
5842
5843
5844
5845
5846
5847
            if pad_between_seqs is None:
                if qkv_format == "thd":
                    pad_between_seqs = (
                        cu_seqlens_q_padded is not None
                        and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
                    ) or (
                        cu_seqlens_kv_padded is not None
                        and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
                    )
                else:
                    pad_between_seqs = False
5848

5849
            attention_params = dpa_utils.AttentionParams(
5850
5851
5852
5853
5854
5855
5856
5857
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=query_layer.shape[-2],
                num_gqa_groups=key_layer.shape[-2],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
5858
5859
                head_dim_qk=query_layer.shape[-1],
                head_dim_v=value_layer.shape[-1],
5860
5861
5862
5863
5864
5865
5866
5867
5868
5869
5870
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
5871
5872
                deterministic=self.deterministic,
                is_training=self.training,
5873
5874
5875
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
            )
5876
            global _attention_backends
5877
5878
5879
5880
5881
5882
5883
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
5884
                fa_utils.use_v3 = fa_utils.v3_is_installed
5885
5886
5887
5888
5889
5890
                (
                    use_flash_attention,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
5891
5892
5893
5894
5895
5896
5897
5898
                ) = dpa_utils.get_attention_backend(attention_params)
                # Set global _attention_backends var using return value
                # from get_attention_backend()
                _attention_backends["use_flash_attention"] = use_flash_attention
                _attention_backends["use_fused_attention"] = use_fused_attention
                _attention_backends["fused_attention_backend"] = fused_attention_backend
                _attention_backends["use_unfused_attention"] = use_unfused_attention
                _attention_backends["backend_selection_requires_update"] = False
5899
                if use_flash_attention:
5900
5901
                    self.logger.info(
                        "Running with FlashAttention backend (version %s)",
5902
                        fa_utils.version if not fa_utils.use_v3 else fa_utils.fa3_version,
5903
                    )
5904
5905
5906
5907
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
5908
                    )
5909
5910
5911
5912
5913
5914
5915
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
5916

5917
5918
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
5919
5920
                    alibi_slopes, _ = dpa_utils.get_alibi(
                        _alibi_cache,
5921
5922
5923
5924
5925
5926
5927
5928
5929
5930
5931
5932
5933
5934
5935
5936
5937
5938
5939
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
5940
                    cp_comm_type=self.cp_comm_type,
5941
5942
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
5943
5944
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
5945
                    quantizers=self.quantizers,
5946
                )
5947

5948
            if use_fused_attention:
5949
5950
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
5951
5952
5953
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
5954
                    fu_core_attention_bias_type = "post_scale_bias"
5955
5956
                    _, fu_core_attention_bias = dpa_utils.get_alibi(
                        _alibi_cache,
5957
5958
5959
5960
5961
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
5962
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
5963
                    )
5964
5965
5966
5967
5968
5969
5970
5971
5972
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
5973
5974
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
5975
5976
5977
5978
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
5979
                        window_size=window_size,
5980
5981
5982
5983
5984
5985
5986
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
5987
                        cp_comm_type=self.cp_comm_type,
5988
5989
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
5990
                        pad_between_seqs=pad_between_seqs,
5991
5992
                    )
                return self.fused_attention(
5993
5994
5995
5996
5997
5998
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
5999
6000
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
6001
6002
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
6003
6004
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
6005
                    window_size=window_size,
6006
                    fused_attention_backend=fused_attention_backend,
6007
6008
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
6009
6010
6011
6012
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
6013
                    cp_comm_type=self.cp_comm_type,
6014
6015
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
6016
                    quantizers=self.quantizers,
6017
                    pad_between_seqs=pad_between_seqs,
6018
                )
6019

6020
            from .cpu_offload import CPUOffloadEnabled
6021

6022
6023
6024
6025
6026
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
6027

6028
6029
6030
6031
6032
6033
6034
6035
6036
6037
6038
6039
            if use_unfused_attention:
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
6040
                        window_size=window_size,
6041
6042
6043
6044
6045
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
6046
6047
6048
                    query_layer,
                    key_layer,
                    value_layer,
6049
6050
6051
6052
6053
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
6054
                    window_size=window_size,
6055
6056
6057
6058
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
6059

6060
            raise ValueError("No dot product attention support for the provided inputs!")
6061
6062


6063
6064
6065
6066
6067
6068
6069
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

6070
6071
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
6072

6073
6074
6075
6076
6077
6078
6079
6080
6081
6082
6083
6084
6085
6086
6087
6088
6089
6090
6091
6092
6093
6094
6095
6096
6097
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
6098
6099
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
6100
                   default = `causal`
6101
6102
6103
6104
6105
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
6106
6107
6108
6109
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
6110
6111
6112
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
6113
                be overridden by :attr:`window_size` in `forward` as well.
6114
6115
6116
6117
6118
6119
6120
6121
6122
6123
6124
6125
6126
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
6127
6128
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
6129
6130
6131
6132
6133
6134
6135
6136
6137
6138
6139
6140
6141
6142
6143
6144
6145
6146
6147
6148
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
6149
          The device on which the parameters of the model will be allocated. It is the user's
6150
6151
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
6152
6153
6154
6155
6156
6157
6158
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
6159
            For that, please use `get_qkv_layout` to gain the layout information.
6160
6161
6162
6163
6164
6165
6166
6167
6168
6169
6170
6171
6172
6173
6174
6175
6176
6177
6178
6179
6180
6181
6182
6183
6184
6185
6186
6187
6188
6189
6190
6191
6192
6193
6194
6195
6196
6197
6198
6199

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 are used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
6200
6201
6202
6203
6204
6205
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
6206
6207
6208
6209
6210
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
6211
        layer_number: Optional[int] = None,
6212
        attn_mask_type: str = "causal",
6213
        window_size: Optional[Tuple[int, int]] = None,
6214
6215
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
6216
        num_gqa_groups: Optional[int] = None,
6217
6218
6219
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
6220
        params_dtype: Optional[torch.dtype] = None,
6221
        return_bias: bool = False,
6222
6223
6224
6225
6226
6227
6228
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
6229
        ub_overlap_ag: bool = False,
6230
6231
6232
6233
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
6234
        bias: bool = True,
6235
        normalization: str = "LayerNorm",
6236
        device: Union[torch.device, str] = "cuda",
6237
        qkv_format: str = "sbhd",
6238
6239
    ) -> None:
        super().__init__()
6240

6241
        self.qkv_format = qkv_format
6242
        self.attn_mask_type = attn_mask_type
6243
        self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
6244
        self.layer_number = layer_number
6245
6246
6247
6248
6249
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
6250
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
6251
        self.num_attention_heads = num_attention_heads
6252
        self.return_bias = return_bias
6253
6254
        self.cp_size = 1
        self.cp_rank = 0
6255
6256
6257
6258
6259
6260
6261

        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()
6262
6263
6264
6265
6266

        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

6267
6268
6269
        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"
6270
6271
6272
6273
6274
6275

        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
6276
6277
6278
6279
6280
6281
6282
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
6283
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
6284
6285
6286
6287

        self.hidden_size_per_attention_head = kv_channels
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
6288
6289
6290
6291
6292
6293
6294

        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
6295
            "params_dtype": self.params_dtype,
6296
            "device": device,
6297
6298
6299
6300
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

cyanguwa's avatar
cyanguwa committed
6301
        if self.attention_type == "self":
6302
6303
            parameters_split = None
            if not fuse_qkv_params:
6304
6305
6306
6307
6308
6309
6310
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
6311
6312
6313
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
6314
                    self.hidden_size_q + 2 * self.hidden_size_kv,
6315
6316
6317
6318
6319
6320
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
cyanguwa's avatar
cyanguwa committed
6321
                    parameters_split=parameters_split,
6322
6323
6324
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
6325
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
6326
                    ub_overlap_ag=ub_overlap_ag,
6327
                    normalization=normalization,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6328
                    ub_name="qkv",
6329
6330
6331
6332
6333
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
6334
                    self.hidden_size_q + 2 * self.hidden_size_kv,
6335
6336
6337
6338
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
cyanguwa's avatar
cyanguwa committed
6339
                    parameters_split=parameters_split,
6340
6341
                    **common_gemm_kwargs,
                )
cyanguwa's avatar
cyanguwa committed
6342
        elif self.attention_type == "cross":
6343
6344
6345
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
6346
                    self.hidden_size_q,
6347
6348
6349
6350
6351
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
6352
                    parameters_split=("query",) if not fuse_qkv_params else None,
6353
6354
6355
6356
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
6357
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
6358
                    ub_overlap_ag=ub_overlap_ag,
6359
                    normalization=normalization,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6360
                    ub_name="qkv",
6361
6362
6363
6364
6365
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
6366
                    self.hidden_size_q,
6367
6368
6369
6370
6371
6372
6373
6374
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
6375
                2 * self.hidden_size_kv,
6376
6377
6378
6379
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
6380
                parameters_split=("key", "value") if not fuse_qkv_params else None,
6381
6382
6383
6384
6385
6386
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
6387
            self.hidden_size_per_attention_head,
6388
6389
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
6390
            qkv_format=self.qkv_format,
6391
6392
6393
6394
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
6395
            layer_number=self.layer_number,
6396
            attention_type=self.attention_type,
6397
6398
6399
6400
        )

        # Linear
        self.proj = Linear(
6401
            self.hidden_size_q,
6402
6403
6404
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
6405
            return_bias=return_bias,
6406
            parallel_mode="row" if set_parallel_mode else None,
6407
6408
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6409
            ub_name="proj",
6410
6411
6412
6413
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Allocates memory for KV cache.

        Returns an uninitialized (``torch.empty``) tensor of shape
        ``[inference_max_sequence_len, batch_size, num_gqa_groups_per_partition,
        hidden_size_per_attention_head]`` with the requested ``dtype``, placed on
        the current CUDA device.
        """
        cache_shape = (
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        target_device = torch.cuda.current_device()
        return torch.empty(cache_shape, dtype=dtype, device=target_device)

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Record the tensor parallel group this module should use; call it
        before executing the forward pass.

        Only the group handle is stored (as ``self.tp_group``); nothing else
        on the module is recomputed by this call.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        # Plain attribute store; consumers read self.tp_group when they need it.
        self.tp_group = tp_group

6438
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        The CP size and rank derived here (``self.cp_size``, ``self.cp_rank``) are
        used by ``forward`` when applying rotary position embeddings. Passing
        ``cp_group=None`` resets them to the single-rank values (1 and 0). The
        arguments are then forwarded unchanged to every submodule that defines
        ``set_context_parallel_group``.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup], None]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
                  `None` disables context parallelism for this module.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p" or "all_gather" or "a2a", "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        if isinstance(cp_group, dist_group_type):
            self.cp_size = get_distributed_world_size(cp_group)
            self.cp_rank = get_distributed_rank(cp_group)
        elif isinstance(cp_group, list):
            assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
            assert (
                cp_comm_type == "a2a+p2p"
            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
            cp_size_a2a = get_distributed_world_size(cp_group[0])
            cp_rank_a2a = get_distributed_rank(cp_group[0])
            cp_size_p2p = get_distributed_world_size(cp_group[1])
            cp_rank_p2p = get_distributed_rank(cp_group[1])
            self.cp_size = cp_size_a2a * cp_size_p2p
            # The a2a sub-group is the inner (fastest-varying) level of the rank layout.
            self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a
        elif cp_group is None:
            # No CP group: fall back to the single-rank defaults so forward() does not
            # keep applying RoPE with a stale cp_size/cp_rank from an earlier setup.
            self.cp_size = 1
            self.cp_rank = 0

        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
6494

6495
6496
6497
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        fast_zero_fill: bool = True,
        pad_between_seqs: Optional[bool] = None,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
                       default = `None`
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner. Falls back to the mask type given
                       at construction when `None`.
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention. Falls back to the window size
                    given at construction when `None`.
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied. A single tensor
                       is used for both query and key.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        pad_between_seqs: Optional[bool], default = `None`
            If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
            If true, there are padding tokens between individual sequences in a packed batch.

        Returns
        -------
        Union[torch.Tensor, Tuple[torch.Tensor, ...]]
            The output of the output projection `proj`. If `return_bias` was set at
            construction, the projection bias is appended. If both `input_layernorm` and
            `return_layernorm_output` were set, the layernorm output is appended. When only
            the projection output is produced, it is returned as a bare tensor, not a tuple.
        """
        # hidden_states: [sq, b, h]

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        if window_size is None:
            window_size = self.window_size
        window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)

        if "padding" in attn_mask_type and attention_mask is not None:
            # attention_mask is either a single tensor (self-attention) or a (q, kv) pair
            # (cross-attention). Wrap a lone tensor so the dtype check runs on the tensor
            # itself rather than once per slice of its batch dimension.
            masks = (
                attention_mask
                if isinstance(attention_mask, (tuple, list))
                else (attention_mask,)
            )
            for mask in masks:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"

        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
        # Pre-allocate memory for key-values for inference
        # =================================================

        if inference_params and self.layer_number is not None:
            assert (
                self.qkv_format != "thd"
            ), "qkv_format == thd is not supported for an inference with KV-cache!"
            # Allocate this layer's KV-cache buffers once; later calls reuse the entry
            # already stored in inference_params.key_value_memory_dict.
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
                )
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )

        # ======================
        # Query, Key, and Value
        # ======================

        fp8_mha = (
            FP8GlobalStateManager.is_fp8_enabled()
            and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
        )
        # The QKV projections may only emit FP8 outputs when RoPE is not applied afterwards;
        # RoPE on Float8Tensors is rejected by the assertion in the rotary section below.
        fp8_output = fp8_mha and rotary_pos_emb is None

        layernorm_output = None
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )

            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
            if self.qkv_weight_interleaved:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along third last dimension
                split_dim = -3

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
            )

            if self.qkv_format == "thd":
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
            else:
                # query: -> [sq, b, np, hn]
                # key, value: -> [sq, b, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
                fp8_output=fp8_output,
            )

            if self.qkv_weight_interleaved:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    2 * self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            key_layer, value_layer = _SplitAlongDim.apply(
                mixed_kv_layer,
                split_dim,
                mixed_kv_layer.shape[split_dim] // 2,
            )
            # Mirror the self-attention branch: thd inputs have no separate batch
            # dimension, so reshaping to 4-D here would add a spurious axis.
            if self.qkv_format == "thd":
                key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (key_layer, value_layer)
                )
            else:
                # key, value: -> [sk, b, ng, hn]
                key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (key_layer, value_layer)
                )

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================

        if rotary_pos_emb is not None:
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
            # duplicate the pos_emb for self attention
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = (rotary_pos_emb,) * 2

            q_pos_emb, k_pos_emb = rotary_pos_emb

            # adjust key and value for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)
                else:
                    raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.")

                # Only the positions of the tokens in this step get embeddings.
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

            query_layer = apply_rotary_pos_emb(
                query_layer,
                q_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_q,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )
            key_layer = apply_rotary_pos_emb(
                key_layer,
                k_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_kv,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )

        # ===========================
        # Core attention computation
        # ===========================

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            qkv_format=self.qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
            window_size=window_size,
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            fast_zero_fill=fast_zero_fill,
            inference_params=inference_params,
            pad_between_seqs=pad_between_seqs,
        )

        # ===================
        # Output. [sq, b, h]
        # ===================
        projection_output = self.proj(
            context_layer,
            is_first_microbatch=is_first_microbatch,
            fp8_grad=isinstance(context_layer, QuantizedTensor),
        )

        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
        if self.input_layernorm and self.return_layernorm_output:
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]