attention.py 326 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
from importlib.metadata import PackageNotFoundError
10
import math
11
import os
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
import warnings
14
import logging
15

cyanguwa's avatar
cyanguwa committed
16
import numpy as np
17
from packaging.version import Version as PkgVersion
18
19
20

import torch

21
import transformer_engine_torch as tex
22
23
24
25
26
from transformer_engine.pytorch.utils import (
    get_cudnn_version,
    nvtx_range_pop,
    nvtx_range_push,
)
27
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
28
29
    fused_attn_fwd,
    fused_attn_bwd,
30
    FusedAttnBackend,
31
32
33
34
35
36
37
    META_QKV,
    META_O,
)
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    get_fp8_te_dtype,
    get_fp8_torch_dtype,
38
)
39
from transformer_engine.pytorch.float8_tensor import Float8Tensor
40
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
41
from transformer_engine.pytorch.module import LayerNormLinear, Linear
42
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
43
44
45
46
47
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
48
    get_default_init_method,
49
50
51
52
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
53
    AttnBiasTypes,
54
    QKVLayouts,
55
    dist_group_type,
56
    TE_DType,
57
58
59
60
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
61
    get_distributed_rank,
62
    checkpoint,
63
64
65
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
66
67
    gather_along_first_dim,
    reduce_scatter_along_first_dim,
68
)
69
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
70
from transformer_engine.pytorch.graph import is_graph_capturing
71
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
72
73
74
75
76
from transformer_engine.pytorch.tensor.quantized_tensor import (
    QuantizedTensor,
    prepare_for_saving,
    restore_from_saved,
)
77

78
79
80
81
82
# Import attention utils
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
83
84


85
86
87
# Setup Attention Logging
attn_log.setup_logging()

# Global vars for flash attn v2 and v3 imports. Every handle starts as None and is
# rebound below only when the matching package is found (for v2: only when its version
# is inside the supported range), so both v2 and v3 names always exist at module level
# instead of the v3 ones being undefined when flash-attn-3 is absent.
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None
flash_attn_func_v3 = None
flash_attn_varlen_func_v3 = None
flash_attn_with_kvcache_v3 = None
_flash_attn_fwd_v3 = None
_flash_attn_bwd_v3 = None

try:
    fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
    pass  # only print warning if use_flash_attention_2 = True in get_attention_backend
else:
    # Blackwell (compute capability >= 10.0) has its own minimum supported version.
    if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
        if fa_utils.version_required_blackwell <= fa_utils.version <= fa_utils.max_version:
            fa_utils.is_installed = True
    elif fa_utils.version_required <= fa_utils.version <= fa_utils.max_version:
        fa_utils.is_installed = True

    if fa_utils.is_installed:
        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
        from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_forward as _flash_attn_varlen_fwd,
        )
        from flash_attn.flash_attn_interface import (
            _flash_attn_varlen_backward as _flash_attn_varlen_bwd,
        )

        # Setup Flash attention utils
        fa_utils.set_flash_attention_version()
    elif (
        torch.cuda.is_available()
        and get_device_compute_capability() >= (8, 0)
        and dpa_utils._NVTE_FLASH_ATTN
    ):
        # flash-attn is present but its version is outside the supported range.
        attn_log.fa_logger.warning(
            "Supported flash-attn versions are %s. Found flash-attn %s.",
            dpa_utils._get_supported_versions(
                (
                    fa_utils.version_required
                    if get_device_compute_capability() < (10, 0)
                    else fa_utils.version_required_blackwell
                ),
                fa_utils.max_version,
            ),
            fa_utils.version,
        )
try:
    fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
except PackageNotFoundError:
    pass  # only print warning if use_flash_attention_3 = True in get_attention_backend
else:
    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flash_attn_3.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flash_attn_3.flash_attn_interface import (
        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
    )
    from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3

    fa_utils.set_flash_attention_3_params()

155
# Global vars for available attention backends and ALiBi cache.
# All entries start unset (None) or False; they are presumably filled in elsewhere in
# this module (backend selection / ALiBi helpers) — NOTE(review): confirm at call sites.
_attention_backends = {
    **dict.fromkeys(
        (
            "attention_params",
            "use_flash_attention",
            "flash_attention_backend",
            "use_fused_attention",
            "fused_attention_backend",
            "use_unfused_attention",
        )
    ),
    "backend_selection_requires_update": False,
}

_alibi_cache = {
    **dict.fromkeys(("_num_heads", "_alibi_slopes", "_max_seqlen_q", "_max_seqlen_kv")),
    "_bottom_right_alignment": True,
    "_alibi_bias": None,
    **dict.fromkeys(("_alibi_slopes_require_update", "_alibi_bias_require_update"), False),
}

177
__all__ = ["DotProductAttention", "MultiheadAttention"]
178
179


180
181
182
183
184
def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` itself when its innermost dimension has unit stride.

    Otherwise return ``tensor.contiguous()``. Note that a tensor whose last stride is 1
    is returned as-is even if it is not fully contiguous.
    """
    if tensor.stride(-1) == 1:
        return tensor
    return tensor.contiguous()


185
186
187
def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism.

    Posts one send (``send_tensor`` -> ``send_dst``) and one receive
    (``recv_src`` -> ``recv_tensor``) on ``cp_group``. Even ranks post the send first and
    odd ranks post the receive first (presumably so neighbouring ranks pair up without
    deadlocking in the non-batched path).

    Returns:
        The request handles: the result of ``torch.distributed.batch_isend_irecv`` when
        ``batch_p2p_comm`` is truthy, otherwise the list of handles returned by
        ``isend``/``irecv`` in the order they were posted.
    """
    send_spec = (torch.distributed.isend, send_tensor, send_dst)
    recv_spec = (torch.distributed.irecv, recv_tensor, recv_src)
    ordered_specs = [send_spec, recv_spec] if rank % 2 == 0 else [recv_spec, send_spec]

    if not batch_p2p_comm:
        return [comm_fn(tensor, peer, cp_group) for comm_fn, tensor, peer in ordered_specs]

    p2p_ops = [
        torch.distributed.P2POp(comm_fn, tensor, peer, cp_group)
        for comm_fn, tensor, peer in ordered_specs
    ]
    return torch.distributed.batch_isend_irecv(p2p_ops)


227
228
229
230
231
232
233
234
235
236
237
238
239
240
@jit_fuser
def flash_attn_fwd_out_correction_init(
    out_init_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_init_step: torch.Tensor,
    seq_dim: int,
):
    """Merge partial outputs of the first step in Attention with context parallelism.

    Scales ``out_init_step`` by ``exp(softmax_lse_init_step - softmax_lse)``. The scale
    has its dim 2 moved to ``seq_dim`` and a trailing singleton dim added for
    broadcasting. Returns a new tensor cast back to ``out_init_step.dtype``.
    """
    scale = torch.exp(softmax_lse_init_step - softmax_lse)
    scale = scale.movedim(2, seq_dim).unsqueeze(-1)
    return (out_init_step * scale).to(out_init_step.dtype)


241
@jit_fuser
def flash_attn_fwd_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Merge partial outputs of each step in Attention with context parallelism.

    Adds ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` into ``out`` in
    place. The scale has its dim 2 moved to ``seq_dim`` and a trailing singleton dim
    added for broadcasting. Returns nothing.
    """
    weight = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim).unsqueeze(-1)
    out.add_(out_per_step * weight)


256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
@jit_fuser
def flash_attn_fwd_second_half_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Merge second half of partial outputs of each step in Attention with context parallelism.

    The last dim of ``softmax_lse`` is viewed as ``[2, -1]`` and its second half is used
    as the reference LSE. The rescaled ``out_per_step`` is added in place into
    ``out.select(seq_dim, 1)``, a view of ``out``. Returns nothing.
    """
    lse_second_half = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :]
    weight = torch.exp(softmax_lse_per_step - lse_second_half).movedim(2, seq_dim)
    weight = weight.unsqueeze(-1)
    out_second_half = out.select(seq_dim, 1)
    out_second_half.add_(out_per_step * weight)


273
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge softmax stats of each step in Attention with context parallelism.

    Overwrites ``softmax_lse`` in place with ``log(exp(a) + exp(b))``, where ``a`` and
    ``b`` are ``softmax_lse`` and ``softmax_lse_per_step``. The result is evaluated as
    ``max + log1p(exp(min - max))`` so the exponent never exceeds zero.
    """
    hi = torch.max(softmax_lse, softmax_lse_per_step)
    lo = torch.min(softmax_lse, softmax_lse_per_step)
    softmax_lse.copy_(hi + torch.log1p(torch.exp(lo - hi)))
283
284


285
286
287
288
289
290
291
292
293
294
295
296
297
@jit_fuser
def flash_attn_fwd_second_half_softmax_lse_correction(
    softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
):
    """Merge second half of softmax stats of each step in Attention with context parallelism.

    Operates on the view ``softmax_lse[..., 1, :]``: that slice is replaced in place by
    the numerically stable log-sum-exp of itself and ``softmax_lse_per_step``. The
    caller is presumably passing ``softmax_lse`` already split into ``[..., 2, s]``.
    NOTE(review): confirm at call sites.
    """
    target = softmax_lse[..., 1, :]
    hi = torch.max(target, softmax_lse_per_step)
    lo = torch.min(target, softmax_lse_per_step)
    target.copy_(hi + torch.log1p(torch.exp(lo - hi)))


298
299
@jit_fuser
def get_cu_seqlens_on_cp_rank(
    cu_seqlens: torch.Tensor,
    cu_seqlens_padded_on_cp_rank: torch.Tensor,
    cp_size: int,
    cp_rank: int,
    first_half: bool,
    second_half: bool,
):
    """Compute cu_seqlens of a context parallelism rank.

    Each sequence is treated as ``2 * cp_size`` chunks whose length is half of the
    padded per-rank length. This rank owns chunk ``cp_rank`` (first half) and chunk
    ``2 * cp_size - cp_rank - 1`` (second half). For each selected half, the number of
    real (unpadded) tokens of that chunk is counted, clamped to ``[0, chunk_len]``. The
    cumulative sum over sequences is returned, starting with 0.
    """
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    chunk_lens = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2
    lower_bound = torch.zeros_like(seqlens)
    result = torch.zeros_like(cu_seqlens)
    if first_half:
        tokens_in_first = (seqlens - cp_rank * chunk_lens).clamp(lower_bound, chunk_lens)
        result[1:].add_(tokens_in_first)
    if second_half:
        mirrored_chunk = 2 * cp_size - cp_rank - 1
        tokens_in_second = (seqlens - mirrored_chunk * chunk_lens).clamp(lower_bound, chunk_lens)
        result[1:].add_(tokens_in_second)
    result.cumsum_(dim=0)
    return result


324
@jit_fuser
def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to
    be contiguous before attention compute. This function is to compute sequence chunk ids for
    reordering.

    Returns an int32 tensor of length ``2 * cp_size`` on ``device``. Position ``r`` holds
    ``2 * r``, and position ``cp_size + r`` holds ``2 * cp_size - 2 * r - 1``.
    """
    total_chunks = 2 * cp_size
    chunk_ids = torch.empty(total_chunks, dtype=torch.int32, device=device)
    for idx in range(cp_size):
        # front half: the first chunk each rank contributes
        chunk_ids[idx] = 2 * idx
        # back half: the mirrored chunks, in reverse rank order
        chunk_ids[cp_size + idx] = total_chunks - 2 * idx - 1
    return chunk_ids


339
@jit_fuser
def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device):
    """
    Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
    We need to reorder sequence chunks back to discontiguous after attention compute. This function
    is to compute sequence chunk ids for reordering.

    Returns an int32 tensor of length ``2 * cp_size`` on ``device``. Position ``2 * r``
    holds ``r``, and position ``2 * r + 1`` holds ``2 * cp_size - r - 1``.
    """
    last_chunk = 2 * cp_size - 1
    chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
    for r in range(cp_size):
        chunk_ids[2 * r] = r
        chunk_ids[2 * r + 1] = last_chunk - r
    return chunk_ids


@jit_fuser
def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
    """Reorder sequence chunk for A2A communication before attention compute.

    Moves the leading CP dim next to ``seq_dim``, and splits the sequence into
    ``2 * cp_size`` chunks. It then permutes those chunks along ``seq_dim`` with
    ``index_select`` using ``chunk_ids_for_a2a``.
    """
    # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]  (no-op move for sbhd, seq_dim == 0)
    moved = x.movedim(0, seq_dim).contiguous()
    # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    chunked = moved.view(
        *moved.shape[:seq_dim], cp_size * 2, -1, *moved.shape[(seq_dim + 2) :]
    )
    return torch.index_select(chunked, dim=seq_dim, index=chunk_ids_for_a2a)


@jit_fuser
def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
    """Reorder sequence chunk for A2A communication after attention compute.

    Moves the chunk dim at ``seq_dim`` to the front, and permutes it with
    ``chunk_ids_for_a2a``. It then splits the leading ``2 * cp_size`` dim into
    ``[cp_size, 2]``.
    """
    # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
    # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    leading = x.movedim(seq_dim, 0).contiguous()
    reordered = torch.index_select(leading, dim=0, index=chunk_ids_for_a2a)
    # [cp*2, ...] -> [cp, 2, ...]
    return reordered.view(cp_size, 2, *reordered.shape[1:])


def flash_attn_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
    chunk_ids_for_a2a: torch.Tensor,
    seq_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """A2A communication for context parallelism.

    Each input is exchanged across ``cp_group`` with an asynchronous
    ``torch.distributed.all_to_all_single``. The work is software-pipelined over
    ``len(inputs) + 2`` iterations. Iteration ``i`` launches the A2A of input ``i - 1``,
    waits for and post-processes output ``i - 2`` on ``cp_stream``, and prepares input
    ``i`` on the current stream (the last two happen in the opposite order when
    ``before_attn`` is False). The current stream finally waits on ``cp_stream``.

    Args:
        a2a_inputs: a tensor or a list of tensors. A list is copied first, so the
            caller's list is never overwritten with internal staging tensors.
        chunk_ids_for_a2a: chunk permutation used by the reorder helpers.
        seq_dim: index of the sequence dim in the inputs.
        cp_size: number of ranks in ``cp_group``.
        cp_group: process group for the all-to-all.
        cp_stream: side stream on which received tensors are post-processed.
        before_attn: True for the heads->sequence exchange before attention,
            False for the reverse exchange after attention.

    Returns:
        A single tensor when exactly one input is given (a bare tensor or a one-element
        list), otherwise a list of outputs in input order.
    """
    # Private copy: the pipeline below overwrites entries with reshaped staging tensors.
    a2a_inputs = list(a2a_inputs) if isinstance(a2a_inputs, list) else [a2a_inputs]
    num_inputs = len(a2a_inputs)
    a2a_outputs, a2a_reqs = [None] * num_inputs, [None] * num_inputs
    if before_attn:
        for i in range(num_inputs + 2):
            if 0 < i < num_inputs + 1:
                # launch the A2A of the input prepared in the previous iteration
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # reorder the sequence chunks
                    x = reorder_seq_chunks_for_a2a_before_attn(
                        x, chunk_ids_for_a2a, seq_dim, cp_size
                    )
                    # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                    # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
                    a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
            if i < num_inputs:
                x = a2a_inputs[i]
                # [b, s, np, hn] -> [b, s, cp, np//cp, hn]
                # or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
                x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
                # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
                # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
                a2a_inputs[i] = x.movedim(-3, 0).contiguous()
    else:
        for i in range(num_inputs + 2):
            if 0 < i < num_inputs + 1:
                # launch the A2A of the input prepared in the previous iteration
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            if i < num_inputs:
                x = a2a_inputs[i]
                # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
                # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
                x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
                # reorder the sequence chunks
                a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
                    x, chunk_ids_for_a2a, seq_dim, cp_size
                )
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
                    # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
                    x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
                    # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
                    # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
                    a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
    # make the caller's stream see the post-processing done on cp_stream
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if num_inputs == 1 else a2a_outputs


449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
_cu_seqlens_info_with_cp_cache = {}


def _get_cu_seqlens_info_with_cp(
    batch_size: int,
    max_seqlen: int,
    cp_size: int,
    cu_seqlens: torch.Tensor,
):
    """Cumulative sequence lengths with CP being considered.

    Returns ``(cu_seqlens // cp_size, cu_seqlens // (cp_size * 2))``, memoised in
    ``_cu_seqlens_info_with_cp_cache`` under ``(batch_size, max_seqlen, cp_size)``.

    NOTE(review): the cache key ignores the contents and device of ``cu_seqlens``. It
    assumes those are fully determined by ``(batch_size, max_seqlen)``; confirm this
    holds for every caller.
    """
    key = (batch_size, max_seqlen, cp_size)
    cached = _cu_seqlens_info_with_cp_cache.get(key)
    if cached is None:
        cached = (cu_seqlens // cp_size, cu_seqlens // (cp_size * 2))
        _cu_seqlens_info_with_cp_cache[key] = cached
    return cached


468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
def get_fa_args(
    forward: bool,
    use_flash_attn_3: bool,
    qkv_format: str,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    dq=None,
    dk=None,
    dv=None,
):
    """Get forward/backward arguments for flash-attn v2 and v3.

    Returns the list of positional arguments to splice into the internal flash-attn
    forward/backward call. ``cu_seqlens_*`` are only passed through for the packed
    ``"thd"`` format. For flash-attn v3 their slots are filled with ``None`` otherwise.
    For v2 they are omitted entirely otherwise.
    """
    is_thd = qkv_format == "thd"
    if use_flash_attn_3:
        cu_seqlens = [cu_seqlens_q, cu_seqlens_kv] if is_thd else [None, None]
        max_seqlens = [max_seqlen_q, max_seqlen_kv]
        if forward:
            return (
                [None] * 4  # k_new, v_new, qv, out
                + cu_seqlens
                + [None] * 3  # cu_seqlens_k_new, seqused_q, seqused_k
                + max_seqlens
                # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin,
                # q_descale, k_descale, v_descale
                + [None] * 8
            )
        # the two Nones are the seqused_q / seqused_k slots
        return cu_seqlens + [None, None] + max_seqlens + [dq, dk, dv]

    varlen_args = [cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv] if is_thd else []
    if forward:
        return varlen_args
    return [dq, dk, dv] + varlen_args


551
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
552
    """
553
554
555
    Attention implementation with context parallelism. Exchange KV between CP ranks
    with P2P in ring topology. Split attention compute into multiple steps, and overlap
    current-step compute with next-step communication.
556
557
558
559
560

    This implementation also supports hierarchical CP, which parallelizes attention
    heads in low-level CP groups and parallelizes sequence dimension in high-level CP
    groups. For more details, please refer to `LongVILA <https://arxiv.org/abs/2408.10188>`_
    and `USP <https://arxiv.org/abs/2405.07719>`_.
561
562
563
    """

    @staticmethod
564
565
566
567
568
569
570
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
571
        cu_seqlens_kv,
572
        max_seqlen_q,
573
        max_seqlen_kv,
574
575
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
576
577
578
579
580
581
582
583
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
584
585
        fp8,
        fp8_meta,
586
587
588
        cp_group,
        cp_global_ranks,
        cp_stream,
589
        quantizers,
590
        pad_between_seqs,
591
        use_flash_attn_3,
592
    ):
593
        # pylint: disable=missing-function-docstring
594
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
595
596
597
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
        if isinstance(cp_group, list):
            assert (
                qkv_format != "thd"
            ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
            assert attn_bias_type == "no_bias", (
                f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
                " yet!"
            )
            cp_group_a2a = cp_group[0]
            cp_size_a2a = get_distributed_world_size(cp_group_a2a)
            rank_a2a = get_distributed_rank(cp_group_a2a)
            cp_group = cp_group[1]
        else:
            cp_group_a2a = None
            cp_size_a2a = 1
            rank_a2a = 0

615
616
        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
617
618
        send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
619
620
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

621
622
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
623

624
        batch_dim = None
625
        seq_dim = None
626
        cu_seqlens_q_half, cu_seqlens_kv_half = None, None
627
        if qkv_format in ["bshd", "sbhd"]:
628
            seq_dim = qkv_format.index("s")
629
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
630
631
632
633
634
635
636
637
638
            cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None
            if use_fused_attention:
                batch_dim = qkv_format.index("b")
                cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp(
                    q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q
                )
                cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp(
                    q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv
                )
639
640
        else:
            qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
641
642
            cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size
            cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size
643
644
645
646
647

        max_seqlen_q = max_seqlen_q // cp_size
        max_seqlen_kv = max_seqlen_kv // cp_size
        cu_seqlens_q_per_step = [None for _ in range(cp_size)]
        cu_seqlens_kv_per_step = [None for _ in range(cp_size)]
648

649
        fused_attn_backend = None
650
        qkv_dtype = q.dtype
651
652
653
        amax_per_step = None
        S_quantizer_per_step = [None for _ in range(cp_size)]
        O_CP_quantizer_per_step = [None for _ in range(cp_size)]
654
655
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
656
657
658
659
660
661
662
663
664
665
666
        is_output_fp8 = False

        (
            QKV_quantizer,
            O_quantizer,
            O_CP_quantizer,
            S_quantizer,
            dQKV_quantizer,
            dQKV_CP_quantizer,
            dO_quantizer,
            dP_quantizer,
667
        ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True)
668

669
670
671
        if fp8:
            if use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
672

673
674
675
676
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
677
678
679
680
681
                is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
                if is_input_fp8:
                    QKV_quantizer = q._quantizer
                    q, k, v = q._data, k._data, v._data
                else:
682
683
                    q_f16, k_f16, v_f16 = q, k, v
                    if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
684
                        q = QKV_quantizer(q_f16)._data
685
                    if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
686
687
688
689
690
                        k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]]
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
                # partial result quantizer
                for i in range(cp_size):
                    S_quantizer_per_step[i] = S_quantizer.copy()
691
                    S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,))
692
                    O_CP_quantizer_per_step[i] = O_CP_quantizer.copy()
693
                    O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))
694
695
696
697
698
699
700
701
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            q_f16 = q
            if use_fused_attention:
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if cp_size_a2a > 1:
702
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device)
703

704
705
706
707
708
            q, k, v = flash_attn_a2a_communicate(
                [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True
            )
            if not fp8:
                q_f16 = q
709
            elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
710
                q_f16 = q
711
                q = QKV_quantizer(q_f16)._data
712

713
714
715
        assert qkv_format == "thd" or (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"
716
        if causal:
717
718
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
719
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
720
721
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
722
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
723
        if attn_bias is not None:
724
            assert len(attn_bias.shape) == 4, (
725
726
727
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
728
729
730
            assert (
                attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0
            ), "Sequence length does not meet divisible requirements!"
731
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
732
733
734
735
736
737
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
738
739
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
740
741
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
742
            )
743
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
744

745
746
747
748
749
        softmax_lse_in_packed_format = False
        if qkv_format == "thd":
            if use_fused_attention:
                softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
            else:
750
                softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3
751

752
        flash_attn_fwd = None
753
754
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
755
756
757
758
            if use_flash_attn_3:
                flash_attn_fwd = (
                    _flash_attn_fwd_v3  # pylint: disable=possibly-used-before-assignment
                )
759
760
                fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
            else:
761
762
763
764
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
765
766
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
767
                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
768
                    fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
769
                elif fa_utils.v2_7_0_plus:
770
771
                    fa_forward_kwargs["window_size_left"] = -1
                    fa_forward_kwargs["window_size_right"] = 0 if causal else -1
772
                if fa_utils.v2_4_plus:
773
                    fa_forward_kwargs["alibi_slopes"] = None
774
                if fa_utils.v2_5_7_plus and qkv_format == "thd":
775
                    fa_forward_kwargs["block_table"] = None
776
                if fa_utils.v2_6_0_plus:
777
                    fa_forward_kwargs["softcap"] = 0.0
778

779
780
781
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
782
        attn_bias_inputs = [None, None]
783
784
785
786
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
787
        attn_biases = [None for _ in range(cp_size)]
788
789
790
791
792
793
794

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
795
        if qkv_format in ["bshd", "sbhd"]:
796
797
798
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
        else:
            p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
799
800
        send_recv_reqs = [[], []]

801
        out = None
802
        for i in range(cp_size + 1):
803
            if i < cp_size:
804
                with torch.cuda.stream(flash_attn_streams[i % 2]):
805
                    # wait until KV is received
806
                    for req in send_recv_reqs[(i + 1) % 2]:
807
808
                        req.wait()

809
810
811
812
813
814
815
816
817
818
819
820
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

821
                    if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
822
823
824
                        kv_inputs[i % 2] = p2p_comm_buffers[i]
                    else:
                        # KV exchange is in BF16/FP16, cast received KV in each step
825
                        kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
826
827
                    if causal:
                        if i == 0:
828
                            if pad_between_seqs:
829
830
831
832
833
834
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True
                                )
835
836
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
837
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
838
839
840
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
857
                            if use_fused_attention:
858
859
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
860
861
862
863
864
865
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
866
                                    ).contiguous()
867
868
869
870
871
872
873
874
875
876
877
878

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
879
                                fp8_meta_kwargs = {}
880
881
882
883
884
885
886
887
888
889
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
890
891
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
892

893
894
895
896
897
898
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
899
900
901
902
903
                                    q_part,
                                    k_part,
                                    v_part,
                                    fake_dtype=qkv_dtype,
                                    fused_attention_backend=fused_attn_backend,
904
905
906
907
908
909
910
911
912
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
913
                                )
914
915
916
917
918
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
919
                            else:
920
921
922
923
924
925
926
927
928
                                fa_forward_args_thd = get_fa_args(
                                    True,
                                    use_flash_attn_3,
                                    qkv_format,
                                    cu_seqlens_q=cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv=cu_seqlens_kv_per_step[i],
                                    max_seqlen_q=max_seqlen_q,
                                    max_seqlen_kv=max_seqlen_kv,
                                )
929
                                fa_outputs = flash_attn_fwd(
930
                                    q_inputs[i % 2],
931
932
933
934
935
936
937
938
939
940
941
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
942
                                    causal=True,
943
                                    **fa_forward_kwargs,
944
                                )
945
                                if not fa_utils.v2_7_0_plus:
946
947
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
948
                                    if not use_flash_attn_3:
949
950
951
952
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
953
                                    if not use_flash_attn_3:
954
                                        rng_states[i] = fa_outputs[3]
955
                        elif i <= rank:
956
                            if pad_between_seqs:
957
958
959
960
961
962
963
964
965
966
967
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    False,
                                )
968
969
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
970
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2)
971
972
973
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                                # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2][0]
                            elif qkv_format == "thd":
                                q_inputs[i % 2] = q
                                # [2, t, np, hn] -> [2, t/2, np, hn]
                                kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                    kv_inputs[i % 2], cu_seqlens_kv_padded, 0
                                )
990
                            if use_fused_attention:
991
                                kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
992
993
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
994
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
1007
                                fp8_meta_kwargs = {}
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
1018
1019
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
1020
1021
1022
1023
1024
1025
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_kv // 2,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1026
1027
1028
1029
                                    q_part,
                                    k_part,
                                    v_part,
                                    qkv_dtype,
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                                    cu_seqlens_kv_padded=(
                                        None
                                        if cu_seqlens_kv_padded is None
                                        else cu_seqlens_kv_padded // 2
                                    ),
                                    **fp8_meta_kwargs,
1044
                                )
1045
1046
1047
1048
1049
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1050
                            else:
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
                                fa_forward_args_thd = get_fa_args(
                                    True,
                                    use_flash_attn_3,
                                    qkv_format,
                                    cu_seqlens_q=cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv=cu_seqlens_kv_per_step[i],
                                    max_seqlen_q=max_seqlen_q,
                                    max_seqlen_kv=max_seqlen_kv // 2,
                                )
                                if use_flash_attn_3 or (
1061
                                    fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
1062
                                ):
1063
                                    fa_forward_kwargs["window_size"] = (-1, -1)
1064
                                elif fa_utils.v2_7_0_plus:
1065
1066
                                    fa_forward_kwargs["window_size_left"] = -1
                                    fa_forward_kwargs["window_size_right"] = -1
1067
                                fa_outputs = flash_attn_fwd(
1068
                                    q_inputs[i % 2],
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
1080
                                    causal=False,
1081
                                    **fa_forward_kwargs,
1082
                                )
1083
                                if not fa_utils.v2_7_0_plus:
1084
1085
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
1086
                                    if not use_flash_attn_3:
1087
1088
1089
1090
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
1091
                                    if not use_flash_attn_3:
1092
                                        rng_states[i] = fa_outputs[3]
1093
                        else:
1094
                            if pad_between_seqs:
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
                                cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True
                                )
                                cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                    cu_seqlens_kv,
                                    cu_seqlens_kv_padded,
                                    cp_size,
                                    (rank - i) % cp_size,
                                    True,
                                    True,
                                )
1106
1107
                            elif qkv_format == "thd":
                                cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2)
1108
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1109
1110
1111
                            else:
                                cu_seqlens_q_per_step[i] = cu_seqlens_q_half
                                cu_seqlens_kv_per_step[i] = cu_seqlens_kv
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
                            if qkv_format == "bshd":
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                q_inputs[i % 2] = q[:, 1, ...]
                                # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    k.shape[0], -1, 2, *k.shape[-2:]
                                )
                            elif qkv_format == "sbhd":
                                # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                                q_inputs[i % 2] = q[1]
                                # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                    -1, k.shape[2], 2, *k.shape[-2:]
                                )
                            elif qkv_format == "thd":
                                # [t, np, hn] -> [t/2, np, hn]
                                q_inputs[i % 2] = tex.thd_read_half_tensor(
                                    q, cu_seqlens_q_padded, 1
                                )
1131
                            if use_fused_attention:
1132
                                q_inputs[i % 2] = q_inputs[i % 2].contiguous()
1133
1134
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
1135
1136
1137
1138
1139
1140
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
1141
                                    ).contiguous()
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153

                                q_part = q_inputs[i % 2]
                                k_part = (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                )
                                v_part = (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                )
1154
                                fp8_meta_kwargs = {}
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
                                if fp8:
                                    q_part = QKV_quantizer.create_tensor_from_data(
                                        q_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    k_part = QKV_quantizer.create_tensor_from_data(
                                        k_part, fake_dtype=qkv_dtype, internal=True
                                    )
                                    v_part = QKV_quantizer.create_tensor_from_data(
                                        v_part, fake_dtype=qkv_dtype, internal=True
                                    )
1165
1166
                                    fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                    fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
1167
1168
1169
1170
1171
1172
                                out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q // 2,
                                    max_seqlen_kv,
                                    cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv_per_step[i],
1173
1174
1175
1176
                                    q_part,
                                    k_part,
                                    v_part,
                                    qkv_dtype,
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
                                    fused_attn_backend,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    cu_seqlens_q_padded=(
                                        None
                                        if cu_seqlens_q_padded is None
                                        else cu_seqlens_q_padded // 2
                                    ),
                                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                    **fp8_meta_kwargs,
1191
                                )
1192
1193
1194
1195
1196
                                if fp8:
                                    softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                                else:
                                    softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                    attn_biases[i] = rest[0] if len(rest) > 0 else None
1197
                            else:
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
                                fa_forward_args_thd = get_fa_args(
                                    True,
                                    use_flash_attn_3,
                                    qkv_format,
                                    cu_seqlens_q=cu_seqlens_q_per_step[i],
                                    cu_seqlens_kv=cu_seqlens_kv_per_step[i],
                                    max_seqlen_q=max_seqlen_q // 2,
                                    max_seqlen_kv=max_seqlen_kv,
                                )
                                if use_flash_attn_3 or (
1208
                                    fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
1209
                                ):
1210
                                    fa_forward_kwargs["window_size"] = (-1, -1)
1211
                                elif fa_utils.v2_7_0_plus:
1212
1213
                                    fa_forward_kwargs["window_size_left"] = -1
                                    fa_forward_kwargs["window_size_right"] = -1
1214
                                fa_outputs = flash_attn_fwd(
1215
                                    q_inputs[i % 2],
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
                                    (
                                        kv_inputs[i % 2][..., 0, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][0]
                                    ),
                                    (
                                        kv_inputs[i % 2][..., 1, :, :]
                                        if qkv_format in ["bshd", "sbhd"]
                                        else kv_inputs[i % 2][1]
                                    ),
                                    *fa_forward_args_thd,
1227
                                    causal=False,
1228
                                    **fa_forward_kwargs,
1229
                                )
1230
                                if not fa_utils.v2_7_0_plus:
1231
1232
                                    out_per_step[i] = fa_outputs[4]
                                    softmax_lse_per_step[i] = fa_outputs[5]
1233
                                    if not use_flash_attn_3:
1234
1235
1236
1237
                                        rng_states[i] = fa_outputs[7]
                                else:
                                    out_per_step[i] = fa_outputs[0]
                                    softmax_lse_per_step[i] = fa_outputs[1]
1238
                                    if not use_flash_attn_3:
1239
                                        rng_states[i] = fa_outputs[3]
1240
                    else:
1241
                        if pad_between_seqs:
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
                            cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True
                            )
                            cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank(
                                cu_seqlens_kv,
                                cu_seqlens_kv_padded,
                                cp_size,
                                (rank - i) % cp_size,
                                True,
                                True,
                            )
1253
1254
                        elif qkv_format == "thd":
                            cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size
1255
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size
1256
1257
1258
                        else:
                            cu_seqlens_q_per_step[i] = cu_seqlens_q
                            cu_seqlens_kv_per_step[i] = cu_seqlens_kv
1259
                        if use_fused_attention:
1260
1261
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
1262
1263
1264
1265
1266
1267
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
1268
                                ).contiguous()
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280

                            q_part = q
                            k_part = (
                                kv_inputs[i % 2][..., 0, :, :]
                                if qkv_format in ["bshd", "sbhd"]
                                else kv_inputs[i % 2][0]
                            )
                            v_part = (
                                kv_inputs[i % 2][..., 1, :, :]
                                if qkv_format in ["bshd", "sbhd"]
                                else kv_inputs[i % 2][1]
                            )
1281
                            fp8_meta_kwargs = {}
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
                            if fp8:
                                q_part = QKV_quantizer.create_tensor_from_data(
                                    q_part, fake_dtype=qkv_dtype, internal=True
                                )
                                k_part = QKV_quantizer.create_tensor_from_data(
                                    k_part, fake_dtype=qkv_dtype, internal=True
                                )
                                v_part = QKV_quantizer.create_tensor_from_data(
                                    v_part, fake_dtype=qkv_dtype, internal=True
                                )
1292
1293
                                fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
                                fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
1294
1295
1296
1297
1298
1299
                            out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
                                is_training,
                                max_seqlen_q,
                                max_seqlen_kv,
                                cu_seqlens_q_per_step[i],
                                cu_seqlens_kv_per_step[i],
1300
1301
1302
1303
                                q_part,
                                k_part,
                                v_part,
                                qkv_dtype,
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
                                fused_attn_backend,
                                attn_scale=softmax_scale,
                                dropout=dropout_p,
                                qkv_layout=qkv_layout,
                                attn_mask_type=attn_mask_type,
                                attn_bias_type=attn_bias_type,
                                attn_bias=attn_bias_inputs[i % 2],
                                cu_seqlens_q_padded=cu_seqlens_q_padded,
                                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                                **fp8_meta_kwargs,
1314
                            )
1315
1316
1317
1318
1319
                            if fp8:
                                softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
                            else:
                                softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                                attn_biases[i] = rest[0] if len(rest) > 0 else None
1320
                        else:
1321
1322
1323
1324
1325
1326
1327
1328
1329
                            fa_forward_args_thd = get_fa_args(
                                True,
                                use_flash_attn_3,
                                qkv_format,
                                cu_seqlens_q=cu_seqlens_q_per_step[i],
                                cu_seqlens_kv=cu_seqlens_kv_per_step[i],
                                max_seqlen_q=max_seqlen_q,
                                max_seqlen_kv=max_seqlen_kv,
                            )
1330
                            fa_outputs = flash_attn_fwd(
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
                                q,
                                (
                                    kv_inputs[i % 2][..., 0, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][0]
                                ),
                                (
                                    kv_inputs[i % 2][..., 1, :, :]
                                    if qkv_format in ["bshd", "sbhd"]
                                    else kv_inputs[i % 2][1]
                                ),
                                *fa_forward_args_thd,
1343
                                causal=False,
1344
                                **fa_forward_kwargs,
1345
                            )
1346
                            if not fa_utils.v2_7_0_plus:
1347
1348
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
1349
                                if not use_flash_attn_3:
1350
1351
1352
1353
                                    rng_states[i] = fa_outputs[7]
                            else:
                                out_per_step[i] = fa_outputs[0]
                                softmax_lse_per_step[i] = fa_outputs[1]
1354
                                if not use_flash_attn_3:
1355
                                    rng_states[i] = fa_outputs[3]
1356
1357
1358
1359

            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
1360
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
1361

1362
                if use_fused_attention:
1363
1364
                    # [b, np, sq, 1] -> [b, np, sq] or
                    # [t, np, 1] -> [t, np]
1365
                    softmax_lse_per_step[i - 1].squeeze_(-1)
1366
1367
1368
1369
                    if softmax_lse_in_packed_format:
                        softmax_lse_per_step[i - 1] = (
                            softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
                        )
1370

1371
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
1372
                    if fp8:
1373
                        out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
1374
1375
                    if i == 1:
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
1376
1377
                        if qkv_format == "thd":
                            out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
1378
1379
1380
1381
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
1382
                    else:
1383
                        if qkv_format == "thd":
1384
                            tex.thd_second_half_lse_correction(
1385
1386
1387
                                softmax_lse,
                                softmax_lse_per_step[i - 1],
                                cu_seqlens_q_padded,
1388
                                softmax_lse_in_packed_format,
1389
                            )
1390
                        else:
1391
1392
1393
                            flash_attn_fwd_second_half_softmax_lse_correction(
                                softmax_lse.view(*softmax_lse.shape[:-1], 2, -1),
                                softmax_lse_per_step[i - 1],
1394
                            )
1395
1396

                if i < cp_size:
1397
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
1398
1399
1400

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

1401
1402
1403
1404
        second_half_lse_seqlen = None
        if causal and rank < (cp_size - 1):
            second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1]

1405
1406
        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
1407
            if i <= rank or not causal:
1408
                if qkv_format in ["bshd", "sbhd"]:
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
                    if i == 0:
                        out = flash_attn_fwd_out_correction_init(
                            out_per_step[0],
                            softmax_lse,
                            softmax_lse_per_step[0],
                            seq_dim,
                        )
                        out = out.view(q.shape)
                    else:
                        flash_attn_fwd_out_correction(
                            out.view(*out_per_step[i].shape),
                            out_per_step[i],
                            softmax_lse,
                            softmax_lse_per_step[i],
                            seq_dim,
                        )
1425
                elif qkv_format == "thd":
1426
1427
1428
1429
1430
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
1431
                        cu_seqlens_q_padded,
1432
                        False,
1433
                        softmax_lse_in_packed_format,
1434
                    )
1435
            else:
1436
                if qkv_format in ["bshd", "sbhd"]:
1437
1438
                    flash_attn_fwd_second_half_out_correction(
                        out,
1439
                        out_per_step[i],
1440
                        softmax_lse,
1441
                        softmax_lse_per_step[i],
1442
                        seq_dim,
1443
                    )
1444
                elif qkv_format == "thd":
1445
1446
1447
1448
1449
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
1450
                        cu_seqlens_q_padded,
1451
                        True,
1452
                        softmax_lse_in_packed_format,
1453
                    )
1454
1455

        kv = p2p_comm_buffers[-1]
1456
1457
1458
1459
1460
1461
1462
1463
        if qkv_format == "bshd":
            out = out.view(out.shape[0], -1, *out.shape[-2:])
            ctx.batch_size = out.shape[0]
        elif qkv_format == "sbhd":
            out = out.view(-1, *out.shape[-3:])
            ctx.batch_size = out.shape[1]

        if cp_size_a2a > 1:
1464
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device)
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
            out = flash_attn_a2a_communicate(
                out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
            )
            if use_fused_attention:
                if qkv_format == "bshd":
                    # [b*s, np, hn] -> [b, s, np, hn]
                    out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                elif qkv_format == "sbhd":
                    # [s*b, np, hn] -> [s, b, np, hn]
                    out = out.view(-1, ctx.batch_size, *out.shape[-2:])
        elif not use_fused_attention:
1476
            out = out.view(-1, *out.shape[-2:])
1477

1478
1479
        if fp8 and use_fused_attention:
            amax_cp_fwd = amax_per_step.amax(dim=1)
1480
1481
            S_quantizer.amax.copy_(amax_cp_fwd[0])
            O_CP_quantizer.amax.copy_(amax_cp_fwd[1])
1482

1483
        out_fp8 = None
1484
        out_f16 = out.to(qkv_dtype)
1485

1486
        if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
1487
1488
1489
            out_fp8 = O_quantizer(out_f16)  # final result

        out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16
1490
1491

        if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
1492
            q_save, kv_save, out_save = q, kv, out_fp8._data
1493
        elif fp8 and is_input_fp8:
1494
            q_save, kv_save, out_save = q, kv, out_f16
1495
        else:
1496
            q_f16 = q_f16.view(q.shape)
1497
1498
            q_save, kv_save, out_save = q_f16, kv, out_f16

1499
        tensors_to_save, tensor_objects = prepare_for_saving(
1500
1501
1502
            q_save,
            kv_save,
            out_save,
1503
            softmax_lse,
1504
1505
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
1506
1507
            *cu_seqlens_q_per_step,
            *cu_seqlens_kv_per_step,
1508
1509
            *rng_states,
            *attn_biases,
1510
        )
1511
1512
1513
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects

1514
1515
1516
        ctx.cp_group_a2a = cp_group_a2a
        ctx.cp_size_a2a = cp_size_a2a
        ctx.rank_a2a = rank_a2a
1517
1518
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
1519
        ctx.cp_stream = cp_stream
1520
1521
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
1522
        ctx.max_seqlen_kv = max_seqlen_kv
1523
        ctx.softmax_scale = softmax_scale
1524
        ctx.qkv_format = qkv_format
1525
        ctx.attn_mask_type = attn_mask_type
1526
1527
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
1528
        ctx.deterministic = deterministic
1529
        ctx.use_fused_attention = use_fused_attention
1530
        ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format
1531
        ctx.second_half_lse_seqlen = second_half_lse_seqlen
1532
1533
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
1534
1535
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
1536
        ctx.use_flash_attn_3 = use_flash_attn_3
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552

        ctx.qkv_dtype = qkv_dtype
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.S_quantizer = S_quantizer
        if ctx.fp8:
            ctx.QKV_quantizer = QKV_quantizer.copy()
            ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone()
            ctx.O_quantizer = O_quantizer.copy()
            ctx.O_quantizer.scale = O_quantizer.scale.clone()
            ctx.S_quantizer = S_quantizer.copy()
            ctx.S_quantizer.scale = S_quantizer.scale.clone()
1553
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
1554

1555
        return out_ret
1556
1557
1558

    @staticmethod
    def backward(ctx, dout):
1559
        # pylint: disable=missing-function-docstring
1560
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
1561
1562
1563
        cp_size_a2a = ctx.cp_size_a2a
        rank_a2a = ctx.rank_a2a

1564
1565
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
1566
1567
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
1568
1569
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1570
        q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
1571
            restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
1572
1573
1574
1575
1576
        )
        cu_seqlens_q_per_step = other_tensors[:cp_size]
        cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
        rng_states = other_tensors[cp_size * 2 : cp_size * 3]
        attn_biases = other_tensors[cp_size * 3 : cp_size * 4]
1577

1578
1579
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
1580
1581

        seq_dim = None
1582
        if ctx.qkv_format in ["bshd", "sbhd"]:
1583
            seq_dim = ctx.qkv_format.index("s")
1584
1585
1586
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
        else:
            qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
1587

1588
        if attn_biases[0] is not None:
1589
1590
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
1591
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
1592
1593
1594
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
1595
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
1596
1597
1598
            )
        else:
            attn_dbias = None
1599
            attn_dbias_ = None
1600

1601
1602
        softmax_lse_ = None
        if causal and ctx.second_half_lse_seqlen is not None:
1603
            if ctx.qkv_format == "thd":
1604
                softmax_lse_ = tex.thd_read_second_half_lse(
1605
1606
1607
1608
                    softmax_lse,
                    cu_seqlens_q_padded,
                    ctx.softmax_lse_in_packed_format,
                    ctx.second_half_lse_seqlen,
1609
                )
1610
1611
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
1612
                softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)
1613
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
1614
1615
1616
1617
1618
1619
            if ctx.use_fused_attention:
                if ctx.softmax_lse_in_packed_format:
                    softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous()
                # [b, np, sq//2] -> [b, np, sq//2, 1] or
                # [t//2, np] -> [t//2, np, 1]
                softmax_lse_.unsqueeze_(-1)
1620
        if ctx.use_fused_attention:
1621
1622
1623
1624
            if ctx.softmax_lse_in_packed_format:
                softmax_lse = softmax_lse.transpose(0, 1).contiguous()
            # [b, np, sq] -> [b, np, sq, 1] or
            # [t, np] -> [t, np, 1]
1625
            softmax_lse.unsqueeze_(-1)
1626
            dout = dout.contiguous()
1627

1628
        dq = None
1629
        dout_dtype = dout.dtype
1630
1631
        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
1632
1633
1634
        amax_per_step = None
        dP_quantizer_per_step = [None for _ in range(cp_size)]
        dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)]
1635
1636
1637
        if ctx.fp8:
            if ctx.use_fused_attention:
                fused_attn_backend = FusedAttnBackend["FP8"]
1638

1639
                if ctx.is_output_fp8:
1640
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
1641
                    ctx.dO_quantizer = dout._quantizer
1642
                else:
1643
                    dout = ctx.dO_quantizer(dout)
1644
1645
1646
1647
1648
1649
                fused_attn_dqkv_dtype = TE_DType[dout._data.dtype]
                dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device)
                dkv_fp8 = torch.empty(
                    (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device
                )
                dkv_fp8_ = torch.empty_like(dkv_fp8)
1650
                p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
1651
                dout = dout._data
1652
                fp8_meta_kwargs = {}
1653
                fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
1654
1655
1656
                amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
                for i in range(cp_size):
                    dP_quantizer_per_step[i] = ctx.dP_quantizer.copy()
1657
                    dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,))
1658
                    dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy()
1659
                    dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))
1660
1661
1662
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
            if ctx.fp8_meta is not None:
                if ctx.is_input_fp8:
                    q = ctx.QKV_quantizer.create_tensor_from_data(
                        q, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    kv = ctx.QKV_quantizer.create_tensor_from_data(
                        kv, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    q = q.dequantize(dtype=ctx.qkv_dtype)
                    kv = kv.dequantize(dtype=ctx.qkv_dtype)
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    if cp_size_a2a == 1:
                        dout = dout.dequantize(dtype=dout_dtype)
                    else:
                        ctx.dO_quantizer = dout._quantizer
                        dout = dout._data
1680
1681
1682
1683
1684
1685
1686
1687
            dq = torch.empty_like(q)
            p2p_comm_buffers = [
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            ]
            p2p_comm_buffers[0][0].copy_(kv)
            if ctx.use_fused_attention:
                fp8_meta_kwargs = {}
1688
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
1689
1690
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

1691
1692
1693
1694
        if cp_size_a2a > 1:
            if not ctx.use_fused_attention:
                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
                dout = dout.view(*out.shape)
1695
1696
1697
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(
                cp_size_a2a, out.device
            )
1698
1699
1700
1701
1702
1703
1704
1705
1706
            out, dout = flash_attn_a2a_communicate(
                [out, dout],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                True,
            )
1707
            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
1708
1709
1710
1711
                dout = ctx.dO_quantizer.create_tensor_from_data(
                    dout, fake_dtype=dout_dtype, internal=True
                )
                dout = dout.dequantize(dtype=dout_dtype)
1712

1713
1714
1715
1716
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        send_recv_reqs = []

1717
        flash_attn_bwd = None
1718
1719
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
1720
1721
1722
1723
            if ctx.use_flash_attn_3:
                flash_attn_bwd = (
                    _flash_attn_bwd_v3  # pylint: disable=possibly-used-before-assignment
                )
1724
1725
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
1726
1727
1728
1729
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
1730
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
1731
                if fa_utils.v2_4_plus:
1732
                    fa_backward_kwargs["alibi_slopes"] = None
1733
                if fa_utils.v2_4_1_plus:
1734
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
1735
                if fa_utils.v2_6_0_plus:
1736
                    fa_backward_kwargs["softcap"] = 0.0
1737

1738
1739
1740
1741
1742
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

1743
1744
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
            if ctx.fp8:
                if i < cp_size - 1:
                    send_recv_reqs = flash_attn_p2p_communicate(
                        rank,
                        send_tensor[0],
                        send_dst,
                        recv_tensor[0],
                        recv_src,
                        ctx.cp_group,
                        batch_p2p_comm,
                    )
                else:
                    dkv_a2a_req = torch.distributed.all_to_all_single(
                        dkv_fp8,
                        dkv_fp8_,
                        group=ctx.cp_group,
                        async_op=True,
                    )
                    send_recv_reqs = [dkv_a2a_req]
            else:
                if i == 0:
                    send_tensor = send_tensor[0]
                    recv_tensor = recv_tensor[0]
                if i == (cp_size - 1):
                    send_tensor = send_tensor[1]
                    recv_tensor = recv_tensor[1]
                send_recv_reqs = flash_attn_p2p_communicate(
                    rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
                )
1774

1775
            kv = p2p_comm_buffers[i % 2][0]
1776
1777
            q_, kv_, out_, dout_ = None, None, None, None
            dq_, dk_, dv_ = None, None, None
1778
            # In reversed order of fwd
1779
            if causal:
1780
                if i == (cp_size - 1):
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        q_, kv_, out_, dout_ = q, kv, out, dout
1795
                    if ctx.use_fused_attention:
1796
1797
1798
1799
1800
1801
1802
1803
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1804
                        if attn_dbias is not None:
1805
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
1826
                                dout_part, fake_dtype=dout_dtype, internal=True
1827
                            )
1828
1829
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
1830
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1831
                            ctx.max_seqlen_q,
1832
1833
1834
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
1835
1836
1837
1838
1839
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
1840
                            dout_dtype,
1841
                            fused_attn_dqkv_dtype,
1842
                            aux_ctx_tensors,
1843
                            fused_attn_backend,
1844
1845
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
1846
1847
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1848
                            qkv_layout=qkv_layout,
1849
                            attn_mask_type=ctx.attn_mask_type,
1850
                            attn_bias_type=ctx.attn_bias_type,
1851
1852
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
1853
                        )
1854
1855
1856
1857
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
1858
                    else:
1859
                        dq_ = torch.empty_like(q_)
1860
                        dkv_ = torch.empty_like(kv_)
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
                        fa_backward_args_thd = get_fa_args(
                            False,
                            ctx.use_flash_attn_3,
                            ctx.qkv_format,
                            cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
                            max_seqlen_q=ctx.max_seqlen_q,
                            max_seqlen_kv=ctx.max_seqlen_kv,
                            dq=dq_,
                            dk=(
                                dkv_[..., 0, :, :]
                                if ctx.qkv_format in ["bshd", "sbhd"]
                                else dkv_[0]
                            ),
                            dv=(
                                dkv_[..., 1, :, :]
                                if ctx.qkv_format in ["bshd", "sbhd"]
                                else dkv_[1]
                            ),
                        )
                        if ctx.use_flash_attn_3 or (
                            fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
                        ):
1884
                            fa_backward_kwargs["window_size"] = (-1, 0)
1885
                        elif fa_utils.v2_7_0_plus:
1886
1887
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = 0
1888
                        if not ctx.use_flash_attn_3:
1889
1890
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
1891
1892
                            dout_,
                            q_,
1893
1894
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
1895
1896
                            out_,
                            softmax_lse,
1897
                            *fa_backward_args_thd,
1898
1899
                            causal=True,
                            **fa_backward_kwargs,
1900
                        )
1901
                elif i >= (cp_size - rank - 1):
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_, out_, dout_ = [
                            x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
                        ]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
                        kv_ = kv[:, 0]
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                        q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
                        # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
                        kv_ = kv[0]
                    elif ctx.qkv_format == "thd":
                        q_, out_, dout_ = q, out, dout
                        # [2, t, np, hn] -> [2, t/2, np, hn]
                        kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
1918
                    if ctx.use_fused_attention:
1919
                        kv_ = kv_.contiguous()
1920
1921
1922
1923
1924
1925
1926
1927
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse,
                                softmax_lse,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1928
                        if attn_dbias is not None:
1929
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
1950
                                dout_part, fake_dtype=dout_dtype, internal=True
1951
                            )
1952
1953
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
1954
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1955
                            ctx.max_seqlen_q,
1956
1957
1958
                            ctx.max_seqlen_kv // 2,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
1959
1960
1961
1962
1963
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
1964
                            dout_dtype,
1965
                            fused_attn_dqkv_dtype,
1966
                            aux_ctx_tensors,
1967
                            fused_attn_backend,
1968
1969
1970
1971
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=(
                                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
                            ),
1972
1973
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1974
                            qkv_layout=qkv_layout,
1975
                            attn_mask_type="padding" if padding else "no_mask",
1976
                            attn_bias_type=ctx.attn_bias_type,
1977
1978
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
1979
                        )
1980
1981
1982
1983
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
1984
                    else:
1985
                        dq_ = torch.empty_like(q_)
1986
                        dkv_ = torch.empty_like(kv_)
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
                        fa_backward_args_thd = get_fa_args(
                            False,
                            ctx.use_flash_attn_3,
                            ctx.qkv_format,
                            cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
                            max_seqlen_q=ctx.max_seqlen_q,
                            max_seqlen_kv=ctx.max_seqlen_kv // 2,
                            dq=dq_,
                            dk=(
                                dkv_[..., 0, :, :]
                                if ctx.qkv_format in ["bshd", "sbhd"]
                                else dkv_[0]
                            ),
                            dv=(
                                dkv_[..., 1, :, :]
                                if ctx.qkv_format in ["bshd", "sbhd"]
                                else dkv_[1]
                            ),
                        )
                        if ctx.use_flash_attn_3 or (
                            fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
                        ):
2010
                            fa_backward_kwargs["window_size"] = (-1, -1)
2011
                        elif fa_utils.v2_7_0_plus:
2012
2013
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
2014
                        if not ctx.use_flash_attn_3:
2015
2016
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
2017
2018
                            dout_,
                            q_,
2019
2020
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2021
2022
                            out_,
                            softmax_lse,
2023
                            *fa_backward_args_thd,
2024
2025
                            causal=False,
                            **fa_backward_kwargs,
2026
2027
                        )
                else:
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
                    if ctx.qkv_format == "bshd":
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
                        # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                        kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                    elif ctx.qkv_format == "sbhd":
                        # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                        q_, out_, dout_ = q[1], out[1], dout[1]
                        # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                        kv_ = kv.view(-1, *kv.shape[-4:])
                    elif ctx.qkv_format == "thd":
                        # [t, np, hn] -> [t/2, np, hn]
                        q_, out_, dout_ = [
                            tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1)
                            for x in [q, out, dout]
                        ]
                        kv_ = kv
2045
                    if ctx.use_fused_attention:
2046
                        q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]]
2047
2048
2049
2050
2051
2052
2053
2054
                        if ctx.fp8:
                            aux_ctx_tensors = [
                                softmax_lse_,
                                softmax_lse_,
                                rng_states[cp_size - i - 1],
                            ]
                        else:
                            aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
2055
                        if attn_dbias is not None:
2056
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077

                        q_part = q_
                        k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
                        v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
                        out_part = out_
                        dout_part = dout_

                        if ctx.fp8:
                            q_part = ctx.QKV_quantizer.create_tensor_from_data(
                                q_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            k_part = ctx.QKV_quantizer.create_tensor_from_data(
                                k_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            v_part = ctx.QKV_quantizer.create_tensor_from_data(
                                v_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            out_part = ctx.O_quantizer.create_tensor_from_data(
                                out_part, fake_dtype=ctx.qkv_dtype, internal=True
                            )
                            dout_part = ctx.dO_quantizer.create_tensor_from_data(
2078
                                dout_part, fake_dtype=dout_dtype, internal=True
2079
                            )
2080
2081
                            fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                            fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
2082
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2083
                            ctx.max_seqlen_q // 2,
2084
2085
2086
                            ctx.max_seqlen_kv,
                            cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv_per_step[cp_size - i - 1],
2087
2088
2089
2090
2091
                            q_part,
                            k_part,
                            v_part,
                            out_part,
                            dout_part,
2092
                            dout_dtype,
2093
                            fused_attn_dqkv_dtype,
2094
                            aux_ctx_tensors,
2095
                            fused_attn_backend,
2096
2097
2098
2099
                            cu_seqlens_q_padded=(
                                None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
                            ),
                            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2100
2101
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
2102
                            qkv_layout=qkv_layout,
2103
                            attn_mask_type="padding" if padding else "no_mask",
2104
                            attn_bias_type=ctx.attn_bias_type,
2105
2106
                            deterministic=ctx.deterministic,
                            **fp8_meta_kwargs,
2107
                        )
2108
2109
2110
2111
                        if ctx.fp8:
                            dq_ = dq_._data
                            dk_ = dk_._data
                            dv_ = dv_._data
2112
                    else:
2113
                        dq_ = torch.empty_like(q_)
2114
                        dkv_ = torch.empty_like(kv_)
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
                        fa_backward_args_thd = get_fa_args(
                            False,
                            ctx.use_flash_attn_3,
                            ctx.qkv_format,
                            cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
                            cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
                            max_seqlen_q=ctx.max_seqlen_q // 2,
                            max_seqlen_kv=ctx.max_seqlen_kv,
                            dq=dq_,
                            dk=(
                                dkv_[..., 0, :, :]
                                if ctx.qkv_format in ["bshd", "sbhd"]
                                else dkv_[0]
                            ),
                            dv=(
                                dkv_[..., 1, :, :]
                                if ctx.qkv_format in ["bshd", "sbhd"]
                                else dkv_[1]
                            ),
                        )
                        if ctx.use_flash_attn_3 or (
                            fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
                        ):
2138
                            fa_backward_kwargs["window_size"] = (-1, -1)
2139
                        elif fa_utils.v2_7_0_plus:
2140
2141
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
2142
                        if not ctx.use_flash_attn_3:
2143
2144
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
2145
2146
                            dout_,
                            q_,
2147
2148
                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
2149
2150
                            out_,
                            softmax_lse_,
2151
                            *fa_backward_args_thd,
2152
2153
                            causal=False,
                            **fa_backward_kwargs,
2154
2155
2156
                        )
            else:
                if ctx.use_fused_attention:
2157
2158
2159
2160
                    if ctx.fp8:
                        aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
                    else:
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
2161
                    if attn_dbias is not None:
2162
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
2163
2164
2165
2166
2167
2168
2169
2170
                    q_part = q
                    k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
                    v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
                    out_part = out
                    dout_part = dout

                    if ctx.fp8:
                        q_part = ctx.QKV_quantizer.create_tensor_from_data(
2171
                            q_part, fake_dtype=ctx.qkv_dtype, internal=True
2172
2173
                        )
                        k_part = ctx.QKV_quantizer.create_tensor_from_data(
2174
                            k_part, fake_dtype=ctx.qkv_dtype, internal=True
2175
2176
                        )
                        v_part = ctx.QKV_quantizer.create_tensor_from_data(
2177
                            v_part, fake_dtype=ctx.qkv_dtype, internal=True
2178
2179
                        )
                        out_part = ctx.O_quantizer.create_tensor_from_data(
2180
                            out_part, fake_dtype=ctx.qkv_dtype, internal=True
2181
2182
                        )
                        dout_part = ctx.dO_quantizer.create_tensor_from_data(
2183
                            dout_part, fake_dtype=dout_dtype, internal=True
2184
                        )
2185
2186
                        fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
                        fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
2187
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
2188
                        ctx.max_seqlen_q,
2189
2190
2191
                        ctx.max_seqlen_kv,
                        cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv_per_step[cp_size - i - 1],
2192
2193
2194
2195
2196
                        q_part,
                        k_part,
                        v_part,
                        out_part,
                        dout_part,
2197
                        dout_dtype,
2198
                        fused_attn_dqkv_dtype,
2199
                        aux_ctx_tensors,
2200
                        fused_attn_backend,
2201
2202
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
2203
2204
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
2205
                        qkv_layout=qkv_layout,
2206
                        attn_mask_type=ctx.attn_mask_type,
2207
                        attn_bias_type=ctx.attn_bias_type,
2208
2209
                        deterministic=ctx.deterministic,
                        **fp8_meta_kwargs,
2210
                    )
2211
2212
2213
2214
2215
2216

                    if ctx.fp8:
                        dq_ = dq_._data
                        dk_ = dk_._data
                        dv_ = dv_._data

2217
                else:
2218
2219
                    dq_ = torch.empty_like(q)
                    dkv_ = torch.empty_like(kv)
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
                    fa_backward_args_thd = get_fa_args(
                        False,
                        ctx.use_flash_attn_3,
                        ctx.qkv_format,
                        cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
                        cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
                        max_seqlen_q=ctx.max_seqlen_q,
                        max_seqlen_kv=ctx.max_seqlen_kv,
                        dq=dq_,
                        dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
                        dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                    )
                    if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
2233
                        fa_backward_kwargs["window_size"] = (-1, -1)
2234
                    elif fa_utils.v2_7_0_plus:
2235
2236
                        fa_backward_kwargs["window_size_left"] = -1
                        fa_backward_kwargs["window_size_right"] = -1
2237
                    if not ctx.use_flash_attn_3:
2238
2239
                        fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                    flash_attn_bwd(
2240
2241
2242
2243
2244
                        dout,
                        q,
                        kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
                        out,
2245
                        softmax_lse,
2246
                        *fa_backward_args_thd,
2247
2248
                        causal=False,
                        **fa_backward_kwargs,
2249
2250
                    )

2251
2252
            if ctx.fp8:
                dq = dq_fp8[(rank + i + 1) % cp_size]
2253
2254
2255
            if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1):
                # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or
                # [sq, b, np, hn] -> [2, sq//2, b, np, hn]
2256
                dq_ = dq_.view(*dq.shape)
2257

2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
            if ctx.fp8:
                if i >= (cp_size - rank - 1) or not causal:
                    dq.copy_(dq_)
                else:
                    if ctx.qkv_format == "bshd":
                        dq[:, 0, ...].fill_(0)
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[0].fill_(0)
                        dq[1].copy_(dq_)
            elif causal:
2269
                if i > (cp_size - rank - 1):
2270
                    dq.add_(dq_)
2271
2272
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
2273
2274
                        dq.copy_(dq_)
                    else:
2275
2276
2277
2278
2279
2280
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
2281
                        elif ctx.qkv_format == "thd":
2282
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add")
2283
                elif i > 0:
2284
2285
2286
2287
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
2288
                    elif ctx.qkv_format == "thd":
2289
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add")
2290
                else:
2291
2292
2293
2294
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
2295
                    elif ctx.qkv_format == "thd":
2296
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy")
2297
2298
2299
2300
2301
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
2302

2303
            if attn_dbias is not None:
2304
                idx = (rank + i + 1) % cp_size
2305
                if i == (cp_size - 1) or not causal:
2306
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
2307
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2308
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
2309
2310
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
2311
2312
2313
2314
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
2315
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
2316
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
2317
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
2318

2319
2320
2321
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
2322

2323
2324
2325
2326
2327
2328
2329
            if ctx.fp8:
                if i < cp_size - 1:
                    dkv = dkv_fp8_[(rank + i + 1) % cp_size]
                else:
                    dkv = dkv_fp8[(rank + i + 1) % cp_size]
            else:
                dkv = p2p_comm_buffers[(i + 1) % 2][1]
2330
            if ctx.use_fused_attention:
2331
                if ctx.qkv_format in ["bshd", "sbhd"]:
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
                    dkv_ = _combine_tensors([dk_, dv_], -2)
                elif ctx.qkv_format == "thd":
                    dkv_ = torch.cat(
                        (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
                    )  # pylint: disable=used-before-assignment
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
                # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
                dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
                dkv_ = dkv_.movedim(-3, 0)
                if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
                    # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or
                    # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(*dkv.shape)
2346

2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
            if ctx.fp8:
                if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
                    if ctx.qkv_format == "bshd":
                        dkv[:, :, 0, ...].copy_(dkv_)
                        dkv[:, :, 1, ...].fill_(0)
                    elif ctx.qkv_format == "sbhd":
                        dkv[:, 0, ...].copy_(dkv_)
                        dkv[:, 1, ...].fill_(0)
                else:
                    dkv.copy_(dkv_)
            elif causal:
2358
                if i == (cp_size - 1):
2359
                    if rank == 0:
2360
2361
2362
2363
2364
2365
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
2366
                        elif ctx.qkv_format == "thd":
2367
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
2368
2369
                    else:
                        dkv.add_(dkv_)
2370
2371
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
2372
2373
2374
2375
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
2376
                        elif ctx.qkv_format == "thd":
2377
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
2378
                    else:
2379
2380
2381
2382
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
2383
                        elif ctx.qkv_format == "thd":
2384
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
2385
2386
2387
2388
2389
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
2390
2391
2392
2393
2394
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

2395
        if ctx.fp8 and ctx.use_fused_attention:
2396
            amax_cp_bwd = amax_per_step.amax(dim=1)
2397
2398
            ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0])
            ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1])
2399
2400
2401
2402
            if ctx.qkv_format in ["bshd", "sbhd"]:
                # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
                # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
                dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
2403
2404
2405
2406
2407
2408
2409
            dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                dq_fp8, fake_dtype=torch.float32, internal=True
            )
            dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                dkv_fp8, fake_dtype=torch.float32, internal=True
            )
            dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]]
2410
2411
            dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]

2412
        if causal:
2413
2414
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
2415
                dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
2416
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
2417
                dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
2418
2419
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
2420
                dq = dq.view(-1, *dq.shape[-3:])
2421
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
2422
2423
                dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])

2424
2425
2426
        if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
            dq[cu_seqlens_q_padded[-1] :].fill_(0)
            dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
2427

2428
        if ctx.fp8 and ctx.is_input_fp8:
2429
2430
            assert torch.uint8 not in [dq.dtype, dkv.dtype]
            dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]]
2431
2432
2433
        dk, dv = dkv[0], dkv[1]

        if cp_size_a2a > 1:
2434
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device)
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
            dq, dk, dv = flash_attn_a2a_communicate(
                [dq, dk, dv],
                chunk_ids_for_a2a,
                seq_dim,
                cp_size_a2a,
                ctx.cp_group_a2a,
                ctx.cp_stream,
                False,
            )
            if ctx.qkv_format == "bshd":
                dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
            elif ctx.qkv_format == "sbhd":
                dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

2449
2450
2451
        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
2452
2453
        # converting torch.uint8 to float8tensor
        if ctx.fp8 and ctx.is_input_fp8:
2454
2455
2456
            dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype)
            dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype)
            dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
2457
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
2458

2459
2460
2461
        return (
            None,
            dq,
2462
2463
            dk,
            dv,
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
2475
            attn_dbias,
2476
2477
2478
2479
2480
            None,
            None,
            None,
            None,
            None,
2481
2482
            None,
            None,
2483
            None,
2484
            None,
2485
            None,
2486
        )
2487
2488


2489
2490
def get_kv_seq_info_after_all_gather(
    local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
    """Compute KV sequence index range and update window size after all-gather.

    After the KV all-gather, the full KV sequence is viewed as ``2 * cp_size``
    consecutive chunks of ``max_seqlen_kv`` tokens each. The query chunk with id
    ``local_chunk_id`` is treated as covering the ``max_seqlen_q`` positions that
    end where KV chunk ``local_chunk_id`` ends (``local_chunk_end_idx``).

    The function returns the KV slice ``[seq_start_idx, seq_end_idx)`` that this
    query chunk needs under ``window_size``. It also returns the window
    re-expressed relative to that slice. Both window sides are shifted by
    ``seq_end_idx - local_chunk_end_idx``, which compensates for the query block
    not ending exactly at the slice end. This is presumably what bottom-right
    aligned masking in the kernels expects; confirm against the backend in use.

    Args:
        local_chunk_id: chunk index into the gathered sequence (presumably in
            ``[0, 2 * cp_size)``).
        cp_size: context-parallel world size.
        max_seqlen_q: number of query tokens in one chunk.
        max_seqlen_kv: number of KV tokens in one chunk.
        window_size: ``(left, right)`` sliding window, where ``-1`` means
            unbounded on that side. ``None`` selects ``(-1, 0)`` when ``causal``
            is true, else ``(-1, -1)``.
        causal: only consulted when ``window_size`` is ``None``.

    Returns:
        ``((seq_start_idx, seq_end_idx), (window_size_left, window_size_right))``
    """
    local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv
    full_seq_end_idx = max_seqlen_kv * cp_size * 2

    if window_size is None:
        window_size = (-1, 0) if causal else (-1, -1)

    if window_size[1] == -1:
        # Unbounded on the right: keep everything up to the end of the full sequence.
        seq_end_idx = full_seq_end_idx
        window_size_right = -1
    else:
        seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1])
        # Shrink the right window by however far the slice end moved past the chunk end.
        window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx

    if window_size[0] == -1:
        seq_start_idx = 0
        window_size_left = -1
    else:
        # First query token sits at local_chunk_end_idx - max_seqlen_q; look back window_size[0].
        seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0])
        # Grow the left window by the same offset that shifted the right side.
        window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx

    return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right)


class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    """
    Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
    Refer section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.

    Each rank splits its local query sequence into two halves. It processes them
    against the all-gathered K/V as the global chunks ``rank`` and
    ``2 * cp_size - rank - 1``. The two per-chunk attention calls run on the
    current CUDA stream and on ``cp_stream`` respectively.

    In backward, K/V are gathered again and dK/dV are accumulated over the whole
    gathered sequence. They are then reduce-scattered back across ``cp_group``.

    Limitations enforced by assertions:
    - only ``"bshd"``/``"sbhd"`` layouts (``"thd"`` is rejected);
    - no padding masks;
    - no attention bias.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        cp_group,
        cp_stream,
        use_flash_attn_3,
    ):
        """Run context-parallel attention with an all-gathered KV sequence.

        Args:
            is_training: forwarded to ``fused_attn_fwd``.
            q, k, v: local tensors in ``qkv_format`` layout. The sequence
                dimension of ``q`` and ``k`` must be divisible by 2.
            cu_seqlens_q: cumulative query lengths. Divided by ``2 * cp_size``
                when fused attention is used.
            max_seqlen_q, max_seqlen_kv: divided by ``2 * cp_size`` to obtain
                per-chunk lengths (presumably the caller passes global lengths).
            cu_seqlens_q_padded: only kept for ``"thd"``, which is rejected, so
                it is effectively replaced by ``None``.
            dropout_p: dropout probability passed to the attention kernels.
            softmax_scale: defaults to ``head_dim ** -0.5`` when ``None``.
            qkv_format: ``"bshd"`` or ``"sbhd"``.
            attn_mask_type: must not contain ``"padding"``. With fused attention,
                causal masks get a ``"_bottom_right"`` suffix.
            attn_bias_type, attn_bias: only ``"no_bias"`` is supported.
            deterministic: stored for the backward pass.
            use_fused_attention: choose cuDNN fused attention over FlashAttention.
            window_size: sliding window ``(left, right)`` or ``None``.
            cp_group: context-parallel process group.
            cp_stream: secondary CUDA stream for the second chunk.
            use_flash_attn_3: use the FlashAttention-3 kernels.

        Returns:
            Attention output. On the fused-attention path it is reshaped to
            ``[b, s, np, hn]`` or ``[s, b, np, hn]``. On the FlashAttention path
            it is flattened to ``[-1, np, hn]``.
        """
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)

        qkv_dtype = q.dtype

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
            attn_mask_type = attn_mask_type + "_bottom_right"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            use_fused_attention or fa_utils.v2_3_plus
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
            if use_flash_attn_3:
                flash_attn_fwd = _flash_attn_fwd_v3
            else:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
                if fa_utils.v2_4_plus:
                    fa_forward_kwargs["alibi_slopes"] = None
                if fa_utils.v2_5_7_plus and qkv_format == "thd":
                    fa_forward_kwargs["block_table"] = None
                if fa_utils.v2_6_0_plus:
                    fa_forward_kwargs["softcap"] = 0.0

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

        # Per-chunk lengths: the global sequence is split into 2 * cp_size chunks.
        max_seqlen_q = max_seqlen_q // (2 * cp_size)
        max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
        if use_fused_attention or qkv_format == "thd":
            cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
        if cu_seqlens_q_padded is not None and qkv_format == "thd":
            cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
        else:
            cu_seqlens_q_padded = None

        # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
        q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
        # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
        k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, cp_group)
        v_ag, _ = gather_along_first_dim(v, cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        # cp_stream must observe the gathered/reordered K/V before step 1 reads them.
        cp_stream.wait_stream(torch.cuda.current_stream())

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]

        # Global chunk ids of the two local query halves.
        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
        kv_seq_range_per_step = [None, None]
        window_size_per_step = [None, None]
        cu_seqlens_kv_per_step = [None, None]
        out_per_step = [None, None]
        softmax_lse_per_step = [None, None]
        rng_states = [None, None]
        out = torch.empty_like(q)

        # Iteration i launches attention for chunk i (if any), then copies the
        # result of chunk i - 1 into `out` on the stream that produced it.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    kv_seq_range_per_step[i], window_size_per_step[i] = (
                        get_kv_seq_info_after_all_gather(
                            local_seq_chunk_ids[i],
                            cp_size,
                            max_seqlen_q,
                            max_seqlen_kv,
                            window_size,
                            causal,
                        )
                    )
                    seq_start_idx, seq_end_idx = kv_seq_range_per_step[i]
                    max_seqlen_kv_ = seq_end_idx - seq_start_idx
                    if use_fused_attention or qkv_format == "thd":
                        cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens(
                            k.shape[1], max_seqlen_kv_, k.device
                        )
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    if use_fused_attention:
                        out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
                            is_training,
                            max_seqlen_q,
                            max_seqlen_kv_,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            qkv_dtype,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=softmax_scale,
                            dropout=dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=attn_mask_type,
                            attn_bias_type=attn_bias_type,
                            attn_bias=attn_bias,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            window_size=window_size_per_step[i],
                        )
                    else:
                        fa_forward_args_thd = get_fa_args(
                            True,
                            use_flash_attn_3,
                            qkv_format,
                            cu_seqlens_q=cu_seqlens_q,
                            cu_seqlens_kv=cu_seqlens_kv_per_step[i],
                            max_seqlen_q=max_seqlen_q,
                            max_seqlen_kv=max_seqlen_kv_,
                        )
                        if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
                            fa_forward_kwargs["window_size"] = window_size_per_step[i]
                        elif fa_utils.v2_7_0_plus:
                            fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
                            fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1]
                        fa_outputs = flash_attn_fwd(
                            q_,
                            k_,
                            v_,
                            *fa_forward_args_thd,
                            causal=causal,
                            **fa_forward_kwargs,
                        )
                        # Output tuple layout differs between FlashAttention versions.
                        if not fa_utils.v2_7_0_plus:
                            out_per_step[i] = fa_outputs[4]
                            softmax_lse_per_step[i] = fa_outputs[5]
                            if not use_flash_attn_3:
                                rng_states[i] = fa_outputs[7]
                        else:
                            out_per_step[i] = fa_outputs[0]
                            softmax_lse_per_step[i] = fa_outputs[1]
                            if not use_flash_attn_3:
                                rng_states[i] = fa_outputs[3]

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if qkv_format == "bshd":
                        out[:, i - 1].copy_(out_per_step[i - 1])
                    elif qkv_format == "sbhd":
                        out[i - 1].copy_(out_per_step[i - 1])

        torch.cuda.current_stream().wait_stream(cp_stream)

        if use_fused_attention:
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
        else:
            out = out.view(-1, *out.shape[-2:])

        # Layout must match the slicing in backward:
        # [q, k, v, cu_seqlens_q, cu_seqlens_q_padded, kv0, kv1, out0, out1, lse0, lse1, rng0, rng1]
        ctx.save_for_backward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_q_padded,
            *cu_seqlens_kv_per_step,
            *out_per_step,
            *softmax_lse_per_step,
            *rng_states,
        )

        ctx.qkv_dtype = qkv_dtype
        ctx.kv_seq_range_per_step = kv_seq_range_per_step
        ctx.window_size_per_step = window_size_per_step
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        ctx.use_flash_attn_3 = use_flash_attn_3
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
        return out

    @staticmethod
    def backward(ctx, dout):
        """Compute gradients for q, k and v.

        K/V are all-gathered again, and per-chunk gradients are computed on two
        streams. dK/dV are accumulated over the full gathered sequence and then
        reduce-scattered across ``ctx.cp_group``.

        Returns:
            A 20-tuple: ``None`` for ``is_training``, then ``dq``, ``dk``, ``dv``
            in the layout of the forward inputs, then ``None`` for the remaining
            16 forward arguments.
        """
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)

        (*saved_tensors,) = ctx.saved_tensors
        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
        cu_seqlens_kv_per_step = saved_tensors[5:7]
        out_per_step = saved_tensors[7:9]
        softmax_lse_per_step = saved_tensors[9:11]
        rng_states = saved_tensors[11:13]
        kv_seq_range_per_step = ctx.kv_seq_range_per_step
        window_size_per_step = ctx.window_size_per_step

        seq_dim = ctx.qkv_format.index("s")
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        dout = dout.view(q.shape)
        dq = torch.empty_like(q)
        # dK/dV cover the whole gathered sequence: k is [s, b, np, hn] here.
        dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
        dv = torch.zeros_like(dk)
        dq_per_step = [None, None]
        dk_per_step = [None, None]
        dv_per_step = [None, None]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream]
        # synchronize dkv update across steps
        dkv_update_done = torch.cuda.Event()

        # [s, b, np, hn] -> [cp, s, b, np, hn]
        k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
        v_ag, _ = gather_along_first_dim(v, ctx.cp_group)

        # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
        v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
        k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
        v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        k_ag = k_ag.view(-1, *k.shape[1:])
        v_ag = v_ag.view(-1, *v.shape[1:])
        ctx.cp_stream.wait_stream(torch.cuda.current_stream())

        local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if ctx.use_flash_attn_3:
                flash_attn_bwd = _flash_attn_bwd_v3
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
                    flash_attn_bwd = _flash_attn_varlen_bwd
                else:
                    flash_attn_bwd = _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                if fa_utils.v2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if fa_utils.v2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
                if fa_utils.v2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0

        # Iteration i launches the backward of chunk i (if any), then scatters
        # the gradients of chunk i - 1 into dq/dk/dv on the stream that produced them.
        for i in range(len(local_seq_chunk_ids) + 1):
            if i < len(local_seq_chunk_ids):
                with torch.cuda.stream(flash_attn_streams[i]):
                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                    # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                    q_ = q.select(seq_dim, i).contiguous()
                    seq_start_idx, seq_end_idx = kv_seq_range_per_step[i]
                    max_seqlen_kv = seq_end_idx - seq_start_idx
                    k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
                    # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
                    k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
                    out_ = out_per_step[i]
                    dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
                    if ctx.use_fused_attention:
                        aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
                        dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
                            ctx.max_seqlen_q,
                            max_seqlen_kv,
                            cu_seqlens_q,
                            cu_seqlens_kv_per_step[i],
                            q_,
                            k_,
                            v_,
                            out_,
                            dout_,
                            ctx.qkv_dtype,
                            TE_DType[dout.dtype],
                            aux_ctx_tensors,
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            cu_seqlens_q_padded=cu_seqlens_q_padded,
                            cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout=qkv_layout,
                            attn_mask_type=ctx.attn_mask_type,
                            attn_bias_type=ctx.attn_bias_type,
                            window_size=window_size_per_step[i],
                            deterministic=ctx.deterministic,
                        )
                    else:
                        dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                            torch.empty_like(x) for x in [q_, k_, v_]
                        ]
                        fa_backward_args_thd = get_fa_args(
                            False,
                            ctx.use_flash_attn_3,
                            ctx.qkv_format,
                            cu_seqlens_q=cu_seqlens_q,
                            cu_seqlens_kv=cu_seqlens_kv_per_step[i],
                            max_seqlen_q=ctx.max_seqlen_q,
                            max_seqlen_kv=max_seqlen_kv,
                            dq=dq_per_step[i],
                            dk=dk_per_step[i],
                            dv=dv_per_step[i],
                        )
                        if not ctx.use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[i]
                        if ctx.use_flash_attn_3 or (
                            fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
                        ):
                            fa_backward_kwargs["window_size"] = window_size_per_step[i]
                        elif fa_utils.v2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                            fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
                        # Gradients are written into dq/dk/dv_per_step[i] via fa_backward_args_thd.
                        flash_attn_bwd(
                            dout_,
                            q_,
                            k_,
                            v_,
                            out_,
                            softmax_lse_per_step[i],
                            *fa_backward_args_thd,
                            causal="causal" in ctx.attn_mask_type,
                            **fa_backward_kwargs,
                        )

            if i > 0:
                with torch.cuda.stream(flash_attn_streams[i - 1]):
                    if ctx.qkv_format == "bshd":
                        dq[:, i - 1].copy_(dq_per_step[i - 1])
                    elif ctx.qkv_format == "sbhd":
                        dq[i - 1].copy_(dq_per_step[i - 1])
                    # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
                    dk_per_step[i - 1], dv_per_step[i - 1] = [
                        x.movedim(seq_dim, 0).contiguous()
                        for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
                    ]
                    # wait until dkv update of last step is done
                    # (the two chunks' KV ranges may overlap, so the add_ below must not race)
                    if i > 1:
                        flash_attn_streams[i - 1].wait_event(dkv_update_done)
                    seq_start_idx, seq_end_idx = kv_seq_range_per_step[i - 1]
                    dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
                    dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
                    if i < len(local_seq_chunk_ids):
                        flash_attn_streams[i - 1].record_event(dkv_update_done)

        torch.cuda.current_stream().wait_stream(ctx.cp_stream)

        # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
        dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
        dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
        chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device)
        dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
        dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
        # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
        dk = dk.view(-1, *dk.shape[-3:])
        dv = dv.view(-1, *dv.shape[-3:])
        dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
        dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)

        # Merge the two query halves and restore the original K/V layout.
        dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
        dk = dk.movedim(0, seq_dim).contiguous()
        dv = dv.movedim(0, seq_dim).contiguous()
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")

        # One entry per forward input (20 in total).
        return (
            None,
            dq,
            dk,
            dv,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    """
    Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
    Refer the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
        window_size,
        fp8,
        fp8_meta,
        cp_group,
        cp_stream,
3010
        quantizers,
3011
        use_flash_attn_3,
3012
    ):
3013
        # pylint: disable=missing-function-docstring
3014
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
3015
3016
3017
3018
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
3019
        qkv_dtype = q.dtype
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
        assert not padding, f"{attn_mask_type} mask type is not supported!"
        assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
        assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
        assert (
            window_size == (-1, 0)
            or window_size == (-1, -1)
            or use_fused_attention
3030
            or fa_utils.v2_3_plus
3031
        ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"
3032

3033
        flash_attn_fwd = None
3034
3035
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
3036
3037
            if use_flash_attn_3:
                flash_attn_fwd = _flash_attn_fwd_v3
3038
3039
                fa_forward_kwargs["window_size"] = window_size
            else:
3040
3041
3042
3043
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
                else:
                    flash_attn_fwd = _flash_attn_fwd
3044
3045
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
3046
                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
3047
                    fa_forward_kwargs["window_size"] = window_size
3048
                elif fa_utils.v2_7_0_plus:
3049
3050
                    fa_forward_kwargs["window_size_left"] = window_size[0]
                    fa_forward_kwargs["window_size_right"] = window_size[1]
3051
                if fa_utils.v2_4_plus:
3052
                    fa_forward_kwargs["alibi_slopes"] = None
3053
                if fa_utils.v2_5_7_plus and qkv_format == "thd":
3054
                    fa_forward_kwargs["block_table"] = None
3055
                if fa_utils.v2_6_0_plus:
3056
                    fa_forward_kwargs["softcap"] = 0.0
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070

        assert (
            q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
        ), "The number of attention heads needs to be divisible by CP size!"

        assert qkv_format != "thd", f"{qkv_format} format is not supported!"
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

        batch_dim = qkv_format.index("b")
        seq_dim = qkv_format.index("s")
        assert (
            q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
        ), "Sequence length per GPU needs to be divisible by 2!"

3071
        fused_attn_backend = None
3072
3073
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
        is_input_fp8 = False
3074
3075
3076
        is_output_fp8 = False

        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
3077
            dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
3078
3079
3080
        )
        if fp8:
            if use_fused_attention:
3081
                fused_attn_backend = FusedAttnBackend["FP8"]
3082
3083
3084
3085
                assert isinstance(k, q.__class__) and isinstance(
                    v, q.__class__
                ), "q, k, and v must have the same type."
                is_input_fp8 = isinstance(q, Float8Tensor)
3086
                is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
3087
                if is_input_fp8:
3088
                    QKV_quantizer = q._quantizer
3089
3090
3091
3092
                    q_fp8, k_fp8, v_fp8 = q, k, v
                    q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
                elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                    q_f16, k_f16, v_f16 = q, k, v
3093
                    q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]]
3094
                fp8_meta_kwargs = {}
3095
3096
                fp8_meta_kwargs["s_quantizer"] = S_quantizer
                fp8_meta_kwargs["o_quantizer"] = O_quantizer  # partial result quantizer
3097
3098
3099
3100
3101
3102
3103
            else:
                assert False, "FP8 is only supported with Fused Attention!"
        else:
            if use_fused_attention:
                fp8_meta_kwargs = {}
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

3104
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device)
3105
3106
3107
3108
        q, k, v = flash_attn_a2a_communicate(
            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
        )

3109
        if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
3110
            q_f16, k_f16, v_f16 = q, k, v
3111
            q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]]
3112
3113
3114

        batch_size = q.shape[batch_dim]
        if use_fused_attention:
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
            q_part, k_part, v_part = q, k, v
            if fp8:
                q_part = QKV_quantizer.create_tensor_from_data(
                    q, fake_dtype=qkv_dtype, internal=True
                )
                k_part = QKV_quantizer.create_tensor_from_data(
                    k, fake_dtype=qkv_dtype, internal=True
                )
                v_part = QKV_quantizer.create_tensor_from_data(
                    v, fake_dtype=qkv_dtype, internal=True
                )
3126
3127
3128
3129
3130
3131
            out, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
3132
3133
3134
3135
                q_part,
                k_part,
                v_part,
                qkv_dtype,
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
                fused_attn_backend,
                attn_scale=softmax_scale,
                dropout=dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                attn_bias=attn_bias,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                window_size=window_size,
                **fp8_meta_kwargs,
            )
3148
3149
            if fp8:
                out = out._data
3150
        else:
3151
3152
3153
3154
3155
3156
3157
3158
3159
            fa_forward_args_thd = get_fa_args(
                True,
                use_flash_attn_3,
                qkv_format,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
            )
3160
            fa_outputs = flash_attn_fwd(
3161
3162
3163
                q,
                k,
                v,
3164
                *fa_forward_args_thd,
3165
                causal=causal,
3166
                **fa_forward_kwargs,
3167
            )
3168
            if not fa_utils.v2_7_0_plus:
3169
                out, softmax_lse = fa_outputs[4], fa_outputs[5]
3170
                rng_state = fa_outputs[7] if not use_flash_attn_3 else None
3171
3172
            else:
                out, softmax_lse = fa_outputs[0], fa_outputs[1]
3173
                rng_state = fa_outputs[3] if not use_flash_attn_3 else None
3174
3175
            aux_ctx_tensors = [softmax_lse, rng_state]

3176
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
        out = flash_attn_a2a_communicate(
            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
        )

        if use_fused_attention:
            if qkv_format == "bshd":
                # [b*s, np, hn] -> [b, s, np, hn]
                out = out.view(batch_size, -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                # [s*b, np, hn] -> [s, b, np, hn]
                out = out.view(-1, batch_size, *out.shape[-2:])

        if fp8:
3190
            if is_output_fp8:
3191
3192
                out_fp8 = O_quantizer.create_tensor_from_data(
                    out, fake_dtype=qkv_dtype, internal=False
3193
3194
                )
                out_ret = out_fp8
3195
                out = out_fp8._data
3196
            else:
3197
                out_fp8 = O_quantizer.create_tensor_from_data(
3198
                    out, fake_dtype=qkv_dtype, internal=True
3199
                )
3200
                out_f16 = out_fp8.dequantize(dtype=qkv_dtype)
3201
3202
3203
3204
                out_ret = out_f16
        else:
            out_ret = out

3205
        if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
3206
            q_save, k_save, v_save, out_save = q, k, v, out
3207
3208
3209
3210
3211
3212
3213
3214
3215
        else:
            if is_input_fp8:
                q_save, k_save, v_save = q, k, v
            else:
                q_save, k_save, v_save = q_f16, k_f16, v_f16
            if is_output_fp8:
                out_save = out
            else:
                out_save = out_f16
3216

3217
        tensors_to_save, tensor_objects = prepare_for_saving(
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
            q_save,
            k_save,
            v_save,
            out_save,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *aux_ctx_tensors,
        )
3228
3229
3230
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects

3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
        ctx.batch_size = batch_size
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.softmax_scale = softmax_scale
        ctx.qkv_format = qkv_format
        ctx.attn_mask_type = attn_mask_type
        ctx.attn_bias_type = attn_bias_type
        ctx.deterministic = deterministic
        ctx.window_size = window_size
        ctx.use_fused_attention = use_fused_attention
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        ctx.fp8_meta = fp8_meta
3246
3247
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
3248
        ctx.use_flash_attn_3 = use_flash_attn_3
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263

        ctx.qkv_dtype = qkv_dtype
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.S_quantizer = S_quantizer
        if ctx.fp8:
            ctx.QKV_quantizer = QKV_quantizer.copy()
            ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone()
            ctx.O_quantizer = O_quantizer.copy()
            ctx.O_quantizer.scale = O_quantizer.scale.clone()
            ctx.S_quantizer = S_quantizer.copy()
            ctx.S_quantizer.scale = S_quantizer.scale.clone()
3264
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
3265
3266
3267
3268
        return out_ret

    @staticmethod
    def backward(ctx, dout):
        # pylint: disable=missing-function-docstring
        # Gradient of the all-to-all (A2A) context-parallel attention.
        #
        # Steps: restore the tensors saved by forward(), prepare dout (and the
        # FP8 quantizer kwargs) for the selected backend, exchange out/dout
        # across ctx.cp_group with flash_attn_a2a_communicate, run the local
        # attention backward (FusedAttention or FlashAttention), exchange
        # dq/dk/dv back, and reshape them for ctx.qkv_format.
        #
        # Returns one entry per forward() input: (None, dq, dk, dv, None, ...).
        # Only q, k and v receive gradients.
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
        cp_size = get_distributed_world_size(ctx.cp_group)

        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            *aux_ctx_tensors,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

        # q, k and v share one format, e.g. "bshd_bshd_bshd".
        qkv_layout = "_".join([ctx.qkv_format] * 3)
        causal = "causal" in ctx.attn_mask_type
        seq_dim = ctx.qkv_format.index("s")

        # Captured before dout may be unwrapped to raw FP8 data below; for a
        # Float8Tensor this is its logical (fake) dtype.
        dout_dtype = dout.dtype
        fused_attn_backend = None
        fused_attn_dqkv_dtype = None
        fp8_meta_kwargs = {}
        if ctx.fp8:
            if not ctx.use_fused_attention:
                # Explicit raise (not `assert False`) so the check survives `python -O`.
                raise AssertionError("FP8 is only supported with Fused Attention!")
            fused_attn_backend = FusedAttnBackend["FP8"]
            if ctx.is_output_fp8:
                assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                ctx.dO_quantizer = dout._quantizer
            else:
                dout = ctx.dO_quantizer(dout)
            fused_attn_dqkv_dtype = TE_DType[dout._data.dtype]
            dout = dout._data
            fp8_meta_kwargs = {
                "s_quantizer": ctx.S_quantizer,
                "dp_quantizer": ctx.dP_quantizer,
                "dqkv_quantizer": ctx.dQKV_quantizer,
            }
        else:
            if ctx.fp8_meta is not None:
                # High-precision backward while FP8 metadata is present (e.g. an
                # FP8 forward with NVTE_FP8_DPA_BWD=0): unwrap an FP8 dout and
                # dequantize FP8 q/k/v to ctx.qkv_dtype.
                if ctx.is_output_fp8:
                    assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                    ctx.dO_quantizer = dout._quantizer
                    dout = dout._data
                if ctx.is_input_fp8:
                    q, k, v = [
                        ctx.QKV_quantizer.create_tensor_from_data(
                            x, fake_dtype=ctx.qkv_dtype, internal=True
                        )
                        for x in (q, k, v)
                    ]
                    q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in (q, k, v)]
            if ctx.use_fused_attention:
                fused_attn_dqkv_dtype = TE_DType[dout_dtype]
                fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

        if not ctx.use_fused_attention:
            out = out.view(ctx.batch_size, -1, *out.shape[-2:])
        dout = dout.view(*out.shape)

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device)
        out, dout = flash_attn_a2a_communicate(
            [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
        )

        if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
            # out/dout carry raw FP8 data in this case; dequantize them for the
            # high-precision kernels.
            out = ctx.O_quantizer.create_tensor_from_data(
                out, fake_dtype=ctx.qkv_dtype, internal=True
            )
            dout = ctx.dO_quantizer.create_tensor_from_data(
                dout, fake_dtype=dout_dtype, internal=True
            )
            out = out.dequantize(dtype=ctx.qkv_dtype)
            dout = dout.dequantize(dtype=dout_dtype)

        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
            if ctx.use_flash_attn_3:
                flash_attn_bwd = (
                    _flash_attn_bwd_v3  # pylint: disable=possibly-used-before-assignment
                )
                fa_backward_kwargs["window_size"] = ctx.window_size
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                flash_attn_bwd = (
                    _flash_attn_varlen_bwd if ctx.qkv_format == "thd" else _flash_attn_bwd
                )
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
                # Keyword set depends on the installed flash-attn version.
                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                    fa_backward_kwargs["window_size"] = ctx.window_size
                elif fa_utils.v2_7_0_plus:
                    fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
                    fa_backward_kwargs["window_size_right"] = ctx.window_size[1]
                if fa_utils.v2_4_plus:
                    fa_backward_kwargs["alibi_slopes"] = None
                if fa_utils.v2_4_1_plus:
                    fa_backward_kwargs["deterministic"] = ctx.deterministic
                if fa_utils.v2_6_0_plus:
                    fa_backward_kwargs["softcap"] = 0.0

        if ctx.use_fused_attention:
            q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout
            if ctx.fp8:
                # Re-wrap raw FP8 data with its quantizer for fused_attn_bwd.
                q_part, k_part, v_part = [
                    ctx.QKV_quantizer.create_tensor_from_data(
                        x, fake_dtype=ctx.qkv_dtype, internal=True
                    )
                    for x in (q_part, k_part, v_part)
                ]
                out_part = ctx.O_quantizer.create_tensor_from_data(
                    out_part, fake_dtype=ctx.qkv_dtype, internal=True
                )
                dout_part = ctx.dO_quantizer.create_tensor_from_data(
                    dout_part, fake_dtype=dout_dtype, internal=True
                )
            dq, dk, dv, _ = fused_attn_bwd(
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_part,
                k_part,
                v_part,
                out_part,
                dout_part,
                dout_dtype,
                fused_attn_dqkv_dtype,
                aux_ctx_tensors,
                fused_attn_backend,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                attn_scale=ctx.softmax_scale,
                dropout=ctx.dropout_p,
                qkv_layout=qkv_layout,
                attn_mask_type=ctx.attn_mask_type,
                attn_bias_type=ctx.attn_bias_type,
                window_size=ctx.window_size,
                deterministic=ctx.deterministic,
                **fp8_meta_kwargs,
            )
            if ctx.fp8:
                dq, dk, dv = [x._data for x in (dq, dk, dv)]
        else:
            softmax_lse, rng_state = aux_ctx_tensors
            # flash-attn writes the gradients into these preallocated buffers.
            dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
            fa_backward_args_thd = get_fa_args(
                False,
                ctx.use_flash_attn_3,
                ctx.qkv_format,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=ctx.max_seqlen_q,
                max_seqlen_kv=ctx.max_seqlen_kv,
                dq=dq,
                dk=dk,
                dv=dv,
            )
            if not ctx.use_flash_attn_3:
                fa_backward_kwargs["rng_state"] = rng_state
            flash_attn_bwd(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                *fa_backward_args_thd,
                causal=causal,
                **fa_backward_kwargs,
            )

        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device)
        dq, dk, dv = flash_attn_a2a_communicate(
            [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
        )

        if ctx.qkv_format == "bshd":
            dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
        elif ctx.qkv_format == "sbhd":
            dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

        if ctx.fp8:
            # Wrap the FP8 gradients; hand back FP8 only when the inputs were FP8.
            dq, dk, dv = [
                ctx.dQKV_quantizer.create_tensor_from_data(
                    x, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
                )
                for x in (dq, dk, dv)
            ]
            if not ctx.is_input_fp8:
                dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]]
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")

        # 25 entries, matching forward()'s inputs: is_training, q, k, v, then
        # 21 non-differentiable arguments.
        return (None, dq, dk, dv) + (None,) * 21


3511
def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_kv,
    max_seqlen_q,
    max_seqlen_kv,
    cu_seqlens_q_padded,
    cu_seqlens_kv_padded,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    cp_comm_type,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
    window_size=None,
    fp8=False,
    fp8_meta=None,
    quantizers=None,
    pad_between_seqs=False,
    use_flash_attn_3=False,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates the configuration, then dispatches on ``cp_comm_type``:

    * ``"p2p"`` and ``"a2a+p2p"`` -> ``AttnFuncWithCPAndKVP2P``
    * ``"all_gather"`` -> ``AttnFuncWithCPAndKVAllGather``, which takes no
      ``cu_seqlens_kv`` / ``cu_seqlens_kv_padded`` arguments
    * ``"a2a"`` -> ``AttnFuncWithCPAndQKVOA2A``

    For ``"a2a+p2p"``, ``cp_group`` must be a list of exactly two process
    groups:

    * If the first group has world size 1, the call runs as ``"p2p"`` over the
      second group.
    * If the second group has world size 1, it runs as ``"a2a"`` over the first.

    All other communication types expect a single process group.

    Sliding-window attention is accepted only with ``"a2a"`` and
    ``"all_gather"``. Here sliding-window means ``window_size`` is not None,
    ``(-1, 0)`` or ``(-1, -1)``.

    Returns:
        The attention output produced by the selected autograd function.

    Raises:
        AssertionError: if the process group(s), ``qkv_format``, attention bias,
            THD padded offsets or sliding window are incompatible with the
            requested configuration.
        ValueError: if ``cp_comm_type`` is not a supported communication type.
    """
    if cp_comm_type == "a2a+p2p":
        assert isinstance(
            cp_group, list
        ), "Hierarchical CP implementation needs multi-level CP groups!"
        assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
        # A level with a single rank is a no-op; fall back to a one-level scheme.
        if get_distributed_world_size(cp_group[0]) == 1:
            cp_group = cp_group[1]
            cp_comm_type = "p2p"
        elif get_distributed_world_size(cp_group[1]) == 1:
            cp_group = cp_group[0]
            cp_comm_type = "a2a"
    else:
        assert isinstance(
            cp_group, dist_group_type
        ), f"Unsupported process group for CP communication type {cp_comm_type}!"

    assert qkv_format in [
        "bshd",
        "sbhd",
        "thd",
    ], f"QKV format of {qkv_format} is not supported with context parallelism!"
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        """Attention bias is only supported with FusedAttention and "causal" """
        """or "no_mask" mask types!"""
    )
    assert qkv_format != "thd" or (
        cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
    ), "cu_seqlens_padded cannot be None with context parallelism + THD format!"

    # (-1, 0) and (-1, -1) are treated as "no sliding window".
    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
    )
    assert not sliding_window_attn or cp_comm_type in [
        "a2a",
        "all_gather",
    ], "The context parallel running configs cannot support sliding window attention!"

    # Positional arguments shared by every context-parallel implementation.
    args = [
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        dropout_p,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ]

    if cp_comm_type in ["p2p", "a2a+p2p"]:
        args += [
            fp8,
            fp8_meta,
            cp_group,
            cp_global_ranks,
            cp_stream,
            quantizers,
            pad_between_seqs,
            use_flash_attn_3,
        ]
        out = AttnFuncWithCPAndKVP2P.apply(*args)
    elif cp_comm_type == "all_gather":
        # The all-gather implementation takes neither cu_seqlens_kv (index 5)
        # nor cu_seqlens_kv_padded (index 9). Delete the higher index first so
        # the lower one does not shift.
        del args[9]
        del args[5]
        args += [window_size, cp_group, cp_stream, use_flash_attn_3]
        out = AttnFuncWithCPAndKVAllGather.apply(*args)
    elif cp_comm_type == "a2a":
        args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3]
        out = AttnFuncWithCPAndQKVOA2A.apply(*args)
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")

    return out


cyanguwa's avatar
cyanguwa committed
3632
class _SplitAlongDim(torch.autograd.Function):
    """Autograd-aware split of one tensor into pieces along ``split_dim``.

    Forward behaves like :func:`torch.split`, optionally squeezing
    ``split_dim`` out of each piece. FP8 inputs (``Float8TensorBase`` /
    ``Float8Tensor``) are split on their raw ``_data``. Each piece is re-wrapped
    with the input's scale-inverse, FP8 dtype and quantizer.

    Backward concatenates the incoming gradients along ``split_dim``. If they
    already sit back-to-back in one storage with identical strides, the
    result is a zero-copy view over that storage. Otherwise it falls back to
    :func:`torch.cat`.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
        squeeze=False,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        ctx.squeeze = squeeze
        # Rank of the unsplit input. Backward uses it to re-insert split_dim
        # into gradients of pieces that squeeze removed it from.
        if isinstance(mixed_x_layer, (Float8TensorBase, Float8Tensor)):
            ctx.input_ndim = mixed_x_layer._data.dim()
        else:
            ctx.input_ndim = mixed_x_layer.dim()

        if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
            mixed_x_layer, Float8Tensor
        ):
            return tuple(
                Float8TensorBase(
                    fp8_scale_inv=mixed_x_layer._scale_inv,
                    fp8_dtype=mixed_x_layer._fp8_dtype,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                    quantizer=mixed_x_layer._quantizer,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        if isinstance(mixed_x_layer, Float8Tensor):
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
        if squeeze:
            out_list = [x.squeeze(split_dim) for x in out_list]
        return out_list

    @staticmethod
    def _unsqueeze(grad, dim):
        """Insert a size-1 ``dim`` into ``grad``, keeping FP8 gradients wrapped."""
        if isinstance(grad, Float8Tensor):
            data = grad._data.unsqueeze(dim)
            return Float8Tensor.make_like(grad, data=data, shape=data.shape)
        return grad.unsqueeze(dim)

    @staticmethod
    def _cat_as_view(tensors, split_dim, split_sizes):
        """Return ``torch.cat(tensors, split_dim)`` as a zero-copy view, or None.

        A view is valid only when all of the following hold:

        * Every tensor shares the first tensor's storage and strides.
        * Each tensor has ``split_sizes[i]`` elements along ``split_dim``.
        * Each tensor starts exactly where the previous one ends along that
          dimension, relative to the first tensor's storage offset.
        """
        first = tensors[0]
        strides = first.stride()
        storage_ptr = first.untyped_storage().data_ptr()
        base_offset = first.storage_offset()
        shape = list(first.shape)
        for i, tensor in enumerate(tensors):
            expected_shape = list(shape)  # fresh copy; `shape` stays untouched
            expected_shape[split_dim] = split_sizes[i]
            # Use the real stride so the check agrees with the set_() below.
            expected_offset = base_offset + sum(split_sizes[:i]) * strides[split_dim]
            if (
                tensor.stride() != strides
                or list(tensor.shape) != expected_shape
                or tensor.untyped_storage().data_ptr() != storage_ptr
                or tensor.storage_offset() != expected_offset
            ):
                return None
        new_shape = list(shape)
        new_shape[split_dim] = sum(split_sizes)
        ret = torch.Tensor().to(device=first.device, dtype=first.dtype)
        ret.set_(first.untyped_storage(), base_offset, new_shape, strides)
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        # pylint: disable=missing-function-docstring
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)

        if ctx.squeeze:
            # Forward removed split_dim from size-1 pieces. Restore it so the
            # concatenated gradient matches the input's shape.
            grad_outputs = [
                (
                    _SplitAlongDim._unsqueeze(grad, ctx.split_dim)
                    if len(grad.shape) < ctx.input_ndim
                    else grad
                )
                for grad in grad_outputs
            ]

        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

        # Always return one entry per forward input: mixed_x_layer, split_dim,
        # split_size_or_sections and squeeze. Autograd drops the surplus
        # trailing None when squeeze was not passed.
        if isinstance(grad_outputs[0], Float8Tensor):
            grad_data = [x._data for x in grad_outputs]
            data = _SplitAlongDim._cat_as_view(grad_data, split_dim, split_sizes)
            if data is None:
                data = torch.cat(grad_data, dim=split_dim)
            return (
                Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape),
                None,
                None,
                None,
            )

        grad = _SplitAlongDim._cat_as_view(grad_outputs, split_dim, split_sizes)
        if grad is None:
            grad = torch.cat(grad_outputs, dim=split_dim)
        return grad, None, None, None
3768
3769
3770
3771
3772
3773
3774
3775
3776


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Parameters
    ----------
    softmax_scale: float
        Scale applied to Q @ K^T before softmax. When QK layer scaling is active
        (see `NVTE_APPLY_QK_LAYER_SCALING` below), it is further divided by `layer_number`
        and the softmax re-applies `layer_number`.
    attention_type: str, default = "self"
        Forwarded to `dpa_utils.get_full_mask` when building the full attention mask.
    attention_dropout: float, default = 0.0
        Dropout probability applied to the attention probabilities.
    attention_dropout_ctx: Callable, default = nullcontext
        Callable returning a context manager that is entered around the dropout call.
    layer_number: Optional[int], default = None
        Used for QK layer scaling and to select this layer's cache when converting a
        paged KV cache to non-paged form.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_type: str = "self",
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        # Enabled only when the env var is set to a non-zero integer AND a layer
        # number is known (the scale is divided by it).
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        inference_params: Optional[InferenceParams] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop.

        Inputs in "bshd", "sbhd_2bshd" and "thd_2bshd" formats are converted to the
        sbhd layout ([s, b, np, hn]) and processed by the sbhd implementation.

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            Q, K, V in the layout described by `qkv_layout`. K/V may carry fewer heads
            than Q (GQA); they are repeated to match, which requires Q's head count to
            be divisible by K's.
        qkv_layout: str
            Must be one of `QKVLayouts`.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            Cumulative sequence lengths. Used to build a padding mask when the mask
            type contains "padding" and no `attention_mask` is given, and (Q only)
            for the thd <-> bshd conversions.
        attn_mask_type, attention_mask, window_size:
            Forwarded to `dpa_utils.get_full_mask`.
        core_attention_bias_type: str
            One of "no_bias", "pre_scale_bias", "post_scale_bias", "alibi".
        core_attention_bias: Optional[torch.Tensor]
            Required for "pre_scale_bias" and "post_scale_bias".
        alibi_slopes: Optional[torch.Tensor]
            Forwarded to `dpa_utils.get_alibi` for "alibi".
        inference_params: Optional[InferenceParams]
            KV-cache state. A paged cache is converted to non-paged first.

        Returns
        -------
        torch.Tensor
            [sq, b, hp] when q_format is "sbhd", [b, sq, hp] for "bshd",
            and [tq, hp] for "thd".

        Raises
        ------
        ValueError
            If `core_attention_bias_type` is not one of the supported values.
        """
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"

        # get q_format and kv_format for training and inference
        qkv_format, q_format, _ = dpa_utils.get_qkv_format(qkv_layout, inference_params)
        if inference_params is not None and inference_params.is_paged:
            key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number)

        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
        if qkv_format == "sbhd_2bshd":
            key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]]

        total_tokens, batch_size = None, None
        if qkv_format == "thd_2bshd":
            total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0]
            query_layer = tex.convert_thd_to_bshd(
                query_layer,
                cu_seqlens_q,
                batch_size,
                inference_params.max_ctx_len,
            )
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )

        if "padding" in attn_mask_type and attention_mask is None:
            attention_mask = dpa_utils.get_padding_mask(
                batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
            )
        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
            dpa_utils.get_full_mask(
                max_seqlen_q,
                max_seqlen_kv,
                attn_mask_type=attn_mask_type,
                attention_mask=attention_mask,
                window_size=window_size,
                attention_type=self.attention_type,
            )
        )

        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        # QK layer scaling only applies to fp16 keys.
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            # Integer division: head counts are exact multiples (asserted above for K).
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # Allocate on the inputs' device (not the "current" CUDA device) so that
        # baddbmm does not fail with a device mismatch.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            # beta=0.0: the preallocated buffer's contents are ignored.
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            ).view(*output_size)

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            # Bias is added before scaling: (QK^T + bias) * scale.
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            matmul_result = matmul_result.view(*output_size) + core_attention_bias
            matmul_result *= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = dpa_utils.get_alibi(
                    _alibi_cache,
                    output_size[1],
                    output_size[2],
                    output_size[3],
                    actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                    actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                    alibi_slopes=alibi_slopes,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
            # Bias is added after scaling: QK^T * scale + bias.
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
                dtype=query_layer.dtype
            )

        else:
            # Without this, uninitialised `torch.empty` memory would flow into softmax.
            raise ValueError(
                f"UnfusedDotProductAttention does not support "
                f"core_attention_bias_type = {core_attention_bias_type}!"
            )

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
        )

        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
        # the columns (pad tokens from k) are already zeroed out during softmax
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if q_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        if q_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        if q_format == "thd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [tq, np, hn]
            context_layer = tex.convert_bshd_to_thd(
                context_layer,
                cu_seqlens_q,
                total_tokens,
            )

            # [tq, np, hn] --> [tq, hp]
            context_layer = context_layer.view(total_tokens, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Autograd op that turns interleaved QKV into separate q, k, v tensors.

    Converts QKV from interleaved (s, b, ...) layout to separate contiguous
    q, k, v tensors in (b, s, ...) layout. The actual data movement is done by
    ``tex.fa_prepare_fwd`` (forward) and ``tex.fa_prepare_bwd`` (backward).
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # The incoming q/k/v are non-contiguous views of a single QKV tensor.
        # `query_layer` alone reaches that tensor's full memory region, so
        # `key_layer` and `value_layer` are not read here.
        packed = tex.fa_prepare_fwd(query_layer)
        # Three chunks along dim 0; each chunk's leading dim is squeezed away.
        q_chunk, k_chunk, v_chunk = split_tensor_along_dim(packed, 0, 3)
        return (
            torch.squeeze(q_chunk, 0),
            torch.squeeze(k_chunk, 0),
            torch.squeeze(v_chunk, 0),
        )

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        # Fuse the three incoming gradients into one buffer, then return the
        # three pieces obtained by splitting its last dimension in three.
        fused_grad = tex.fa_prepare_bwd(dq, dk, dv)
        grad_q, grad_k, grad_v = split_tensor_along_dim(fused_grad, -1, 3)
        return grad_q, grad_k, grad_v

4055

4056
class FlashAttention(torch.nn.Module):
4057
    """Dot product attention, using HazyResearch flash-attn package:
4058
    https://github.com/Dao-AILab/flash-attention
4059
4060
4061
4062
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Configure the flash-attn backend.

        Parameters
        ----------
        softmax_scale: float
            Scale passed to the flash-attn kernels as ``softmax_scale``.
        attention_dropout: float, default = 0.0
            Dropout probability; ``forward`` passes it to the kernels only in training mode.
        attention_dropout_ctx: Callable, default = nullcontext
            Callable returning a context manager entered around the attention kernel call.
        attention_type: str, default = "self"
            "self" or cross attention; affects how padded inputs are packed in ``forward``.
        layer_number: Optional[int], default = None
            Stored as-is, or as 1 when not given.
        deterministic: bool, default = False
            Forwarded to the kernels that accept a ``deterministic`` flag.
        """
        super().__init__()

        # Enforce the supported flash-attn version window when the package is present.
        if fa_utils.is_installed:
            assert (
                fa_utils.version >= fa_utils.version_required
            ), f"FlashAttention minimum version {fa_utils.version_required} is required."
            assert (
                fa_utils.version <= fa_utils.max_version
            ), f"FlashAttention maximum version {fa_utils.max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = layer_number if layer_number is not None else 1
        self.deterministic = deterministic

        # Shared named logger; attach the module's stream handler only once.
        fa_logger = logging.getLogger("FlashAttention")
        self.logger = fa_logger
        fa_logger.setLevel(attn_log._log_level)
        if not fa_logger.hasHandlers():
            fa_logger.addHandler(attn_log._stream_handler)
4090
4091
4092
4093
4094
4095

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
4096
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
4097
4098
4099
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
4100
4101
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
4102
        attn_mask_type: str = "causal",
4103
        window_size: Optional[Tuple[int, int]] = None,
4104
        alibi_slopes: Optional[torch.Tensor] = None,
4105
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
4106
        cp_global_ranks: List[int] = None,
4107
        cp_stream: torch.cuda.Stream = None,
4108
        cp_comm_type: str = "p2p",
4109
4110
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
4111
        quantizers=None,
4112
4113
        inference_params: Optional[InferenceParams] = None,
        flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
4114
4115
4116
    ) -> torch.Tensor:
        """flash-attn fprop"""

4117
4118
4119
4120
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
4121
4122
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
4123
        ), "FlashAttention currently only supports CUDA tensors."
4124
4125
        assert (
            qkv_layout in QKVLayouts
4126
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"
4127

4128
4129
4130
4131
4132
4133
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
4134
        context_parallel = cp_size > 1
4135

4136
4137
        # get q_format and kv_format for training and inference
        qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params)
4138

4139
        # convert q, k, v to bshd if they are in sbhd; qkv_format doesn't change
4140
4141
4142
4143
4144
4145
4146
4147
4148
4149
4150
4151
4152
        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
            if qkv_format == "sbhd":
                # For now just 128, will make it more general in the future
                if (
                    query_layer.shape[-1] == 128
                    and query_layer.shape[0] * query_layer.shape[1] >= 512
                    and qkv_layout == "sbh3d"
                ):
                    query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                        query_layer, key_layer, value_layer
                    )
                else:
                    query_layer, key_layer, value_layer = [
4153
4154
                        x.transpose(0, 1).contiguous()
                        for x in (query_layer, key_layer, value_layer)
4155
                    ]
4156
4157
            elif q_format == "sbhd" and kv_format == "bshd":
                query_layer = query_layer.transpose(0, 1).contiguous()
4158
            if context_parallel:
4159
                query_layer, key_layer, value_layer = [
4160
4161
4162
4163
4164
                    x.contiguous() for x in (query_layer, key_layer, value_layer)
                ]
        else:
            if qkv_format == "sbhd":
                query_layer._data, key_layer._data, value_layer._data = [
4165
                    x.transpose(0, 1).contiguous()
4166
4167
                    for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
4168
                query_layer, key_layer, value_layer = [
4169
                    Float8Tensor.make_like(x, data=x._data, shape=x._data.shape)
4170
4171
                    for x in (query_layer, key_layer, value_layer)
                ]
4172
4173
4174
4175
4176
            elif q_format == "sbhd" and kv_format == "bshd":
                query_layer._data = query_layer._data.transpose(0, 1).contiguous()
                query_layer = Float8Tensor.make_like(
                    query_layer, data=query_layer._data, shape=query_layer._data.shape
                )
4177
            if context_parallel:
4178
4179
                query_layer._data, key_layer._data, value_layer._data = [
                    x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
4180
                ]
4181

4182
4183
4184
4185
4186
4187
4188
4189
        # get batch_size, max_seqlen and cu_seqlens
        batch_size, context_len = None, None
        if inference_params is None:
            if qkv_format in ["sbhd", "bshd"]:
                batch_size = query_layer.shape[0]
                max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
4190

4191
4192
4193
4194
                if "padding" in attn_mask_type:
                    assert (
                        not context_parallel
                    ), "Padding mask not supported with context parallelism!"
4195

4196
4197
4198
4199
4200
                    # [b * s, h, d]
                    query_layer, key_layer, value_layer = [
                        x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
                        for x in [query_layer, key_layer, value_layer]
                    ]
4201

4202
                    if self.attention_type == "self":
4203
                        assert (
4204
4205
4206
4207
4208
4209
4210
4211
4212
4213
4214
4215
4216
4217
                            max_seqlen_q == max_seqlen_kv
                        ), "Maximum sequence length for Q and KV should be the same."
                        if cu_seqlens_q is None:
                            assert (
                                attention_mask is not None
                            ), "Please provide attention_mask for padding!"
                            cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
                                attention_mask
                            )
                        else:
                            indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
                        cu_seqlens_kv = cu_seqlens_q
                        query_layer, key_layer, value_layer = dpa_utils.PackTensors.apply(
                            indices_q, query_layer, key_layer, value_layer
4218
                        )
4219
                    else:
4220
4221
4222
4223
4224
4225
4226
4227
4228
4229
4230
4231
4232
4233
4234
4235
4236
                        if cu_seqlens_q is None or cu_seqlens_kv is None:
                            assert (
                                attention_mask is not None
                            ), "Please provide attention_mask for padding!"
                            cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
                                attention_mask[0]
                            )
                            cu_seqlens_kv, indices_kv = dpa_utils.get_cu_seqlens_and_indices(
                                attention_mask[1]
                            )
                        else:
                            indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
                            indices_kv = dpa_utils.get_indices(max_seqlen_kv, cu_seqlens_kv)
                        query_layer = dpa_utils.PackTensors.apply(indices_q, query_layer)
                        key_layer, value_layer = dpa_utils.PackTensors.apply(
                            indices_kv, key_layer, value_layer
                        )
4237
                else:
4238
4239
4240
4241
4242
4243
                    # Cumulative sequence lengths for unpadded data
                    if cu_seqlens_q is None:
                        cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
4244
                        )
4245
4246
4247
4248
4249
                    if cu_seqlens_kv is None:
                        cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
4250
                        )
4251
4252
4253
4254
4255
4256
4257
4258
4259
4260
4261
4262
4263
4264
4265
4266
4267
4268
4269
4270
4271
4272
4273
            elif qkv_format == "thd":
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                if max_seqlen_q is None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    max_seqlen_q = seqlens_q.max().item()
                if max_seqlen_kv is None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    max_seqlen_kv = seqlens_kv.max().item()
        else:
            if qkv_format in ["sbhd_2bshd", "bshd"]:
                # q is in bshd in both cases from conversion above or the original input
                batch_size, context_len = query_layer.shape[:2]
                cu_seqlens_q = cu_seqlens_q[: batch_size + 1]
                cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1]
                # convert from bshd to thd_2bshd for flash_attn_varlen_func/_with_kvcache;
                # kernel assumes tensor is contiguous
                if isinstance(query_layer, Float8Tensor):
                    query_layer._data = tex.convert_bshd_to_thd(
                        query_layer._data,
                        cu_seqlens_q,
                        batch_size * context_len,
4274
                    )
4275
4276
                    query_layer = Float8Tensor.make_like(
                        query_layer, data=query_layer._data, shape=query_layer._data.shape
4277
                    )
4278
4279
4280
4281
4282
                else:
                    query_layer = tex.convert_bshd_to_thd(
                        query_layer,
                        cu_seqlens_q,
                        batch_size * context_len,
4283
                    )
4284

4285
4286
4287
        use_flash_attn_3 = False
        if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
            use_flash_attn_3 = True
4288
4289
4290
        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
4291
4292
4293
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
4294
            with self.attention_dropout_ctx():
4295
                output = attn_forward_func_with_cp(
4296
4297
4298
4299
4300
4301
4302
4303
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
4304
4305
                    cu_seqlens_q if qkv_format == "thd" else None,
                    cu_seqlens_kv if qkv_format == "thd" else None,
4306
                    self.attention_dropout if self.training else 0.0,
4307
4308
4309
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
4310
                    cp_comm_type,
4311
                    softmax_scale=self.softmax_scale,
4312
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
4313
                    attn_mask_type=attn_mask_type,
4314
                    deterministic=self.deterministic,
4315
                    window_size=window_size,
4316
                    quantizers=quantizers,
4317
                    pad_between_seqs=False,
4318
                    use_flash_attn_3=use_flash_attn_3,
4319
4320
                )
        else:
4321
4322

            from .cpu_offload import CPUOffloadEnabled
4323

4324
4325
4326
4327
4328
4329
            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

4330
            with self.attention_dropout_ctx():
4331
4332
4333
4334
4335
4336
4337
4338
4339
4340
4341
4342
                #       | API                     | use cases
                # ----------------------------------------------------------------------
                # FA v2 | flash_attn_func         | bshd/sbhd + not padding
                #       | flash_attn_varlen_func  | bshd/sbhd + padding
                #       |                         | thd + padding
                #       |                         | KV cache (not-paged/paged), i.e.
                #       |                         |     bshd/sbhd/thd + padding
                # FA v3 | flash_attn_func         | bshd/sbhd + not padding
                #       | flash_attn_varlen_func  | bshd/sbhd + padding
                #       |                         | thd + padding
                #       | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e.
                #       |                         |     bshd/sbhd/thd + padding
4343
4344
4345
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
                    func = (
4346
4347
4348
4349
4350
4351
4352
4353
4354
4355
4356
4357
4358
4359
4360
4361
4362
4363
4364
4365
4366
4367
4368
4369
4370
4371
4372
4373
4374
4375
4376
4377
4378
4379
4380
4381
4382
4383
4384
4385
                        flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
                    )  # pylint: disable=possibly-used-before-assignment
                else:
                    if not use_flash_attn_3:
                        func = flash_attn_varlen_func
                    elif inference_params is None:
                        func = flash_attn_varlen_func_v3  # pylint: disable=possibly-used-before-assignment
                    else:
                        func = flash_attn_with_kvcache_v3  # pylint: disable=possibly-used-before-assignment
                    if not use_flash_attn_3 or inference_params is None:
                        fa_optional_forward_args_thd.append(cu_seqlens_q)
                        fa_optional_forward_args_thd.append(cu_seqlens_kv)
                        fa_optional_forward_args_thd.append(max_seqlen_q)
                        fa_optional_forward_args_thd.append(max_seqlen_kv)
                if not use_flash_attn_3:
                    fa_optional_forward_kwargs = {}
                    if fa_utils.v2_3_plus:
                        fa_optional_forward_kwargs["window_size"] = window_size
                    if fa_utils.v2_4_plus:
                        fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                    if fa_utils.v2_4_1_plus:
                        fa_optional_forward_kwargs["deterministic"] = self.deterministic
                    if inference_params is not None:
                        # use block_table kwarg to support thd_2bshd for non-paged
                        fa_optional_forward_kwargs["block_table"] = (
                            inference_params.cache_manager.page_table[:batch_size]
                            if inference_params.is_paged
                            else inference_params.cache_manager.batch_indices_post_step.unsqueeze(
                                1
                            )[:batch_size]
                        )
                    output = func(
                        query_layer,
                        key_layer,
                        value_layer,
                        *fa_optional_forward_args_thd,
                        self.attention_dropout if self.training else 0.0,
                        softmax_scale=self.softmax_scale,
                        causal="causal" in attn_mask_type,
                        **fa_optional_forward_kwargs,
4386
                    )
4387
                else:
4388
4389
                    fa_3_optional_forward_kwargs = {}
                    fa_3_optional_forward_kwargs["window_size"] = window_size
4390
4391
4392
4393
4394
4395
4396
4397
4398
4399
4400
4401
                    if inference_params is None:
                        fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                    else:
                        fa_3_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
                        fa_3_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
                        cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                        fa_3_optional_forward_kwargs["cache_seqlens"] = cache_seqlens
                        # flash_attn_with_kvcache accepts thd_2bshd for non-paged
                        if inference_params.is_paged:
                            fa_3_optional_forward_kwargs["page_table"] = (
                                inference_params.cache_manager.page_table[:batch_size]
                            )
4402
                    if fp8:
4403
                        QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
4404
                        torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
4405
                        torch_orig_dtype = query_layer.dtype
4406
4407
4408
4409
4410
4411
4412
4413
4414
4415
4416

                        def convert_to_torch_float8(tensor, dtype):
                            out = torch.Tensor().to(device=tensor.device, dtype=dtype)
                            out.set_(
                                tensor._data.untyped_storage(),
                                tensor._data.storage_offset(),
                                tensor._data.shape,
                                tensor._data.stride(),
                            )
                            return out

4417
4418
4419
4420
4421
                        # "fp8_mha" decides outputs in fp8, while inputs are inferred from
                        # the real dtype
                        assert isinstance(key_layer, query_layer.__class__) and isinstance(
                            value_layer, query_layer.__class__
                        ), "q, k, and v must have the same type."
4422
                        if not isinstance(query_layer, Float8Tensor):
4423
                            query_layer, key_layer, value_layer = (
4424
                                QKV_quantizer(x) for x in [query_layer, key_layer, value_layer]
4425
                            )
4426
4427
4428
4429
                        batch_size = cu_seqlens_q.shape[0] - 1
                        num_heads_k = key_layer.shape[-2]
                        fa_3_optional_forward_kwargs["q_descale"] = (
                            query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k)
4430
                        )
4431
                        fa_3_optional_forward_kwargs["k_descale"] = key_layer._scale_inv.unsqueeze(
4432
                            0
4433
4434
4435
                        ).repeat(batch_size, num_heads_k)
                        fa_3_optional_forward_kwargs["v_descale"] = (
                            value_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k)
4436
                        )
4437
4438
4439
                        query_layer, key_layer, value_layer = (
                            convert_to_torch_float8(x, torch_dtype)
                            for x in [query_layer, key_layer, value_layer]
4440
                        )
4441
                    try:
4442
                        output = func(
4443
4444
4445
4446
4447
4448
4449
4450
                            query_layer,
                            key_layer,
                            value_layer,
                            *fa_optional_forward_args_thd,
                            softmax_scale=self.softmax_scale,
                            causal="causal" in attn_mask_type,
                            **fa_3_optional_forward_kwargs,
                        )
4451
4452
                        if isinstance(output, (List, Tuple)):
                            output = output[0]
4453
                    except TypeError as e:
4454
                        if fa_utils.v3_0_0_beta:
4455
4456
4457
4458
                            e.args = (
                                e.args[0]
                                + ". Please update your flash-attn v3 (beta) installation as it "
                                + "may have added more supported arguments to its API. \n"
4459
                                + fa_utils.v3_installation_steps,
4460
4461
4462
4463
4464
4465
4466
4467
                            ) + e.args[1:]
                        raise

                    if fp8:
                        output = output.to(dtype=torch_orig_dtype)
                    if fp8 and fp8_meta["recipe"].fp8_mha:
                        O_quantizer = quantizers["scaling_fwd"][META_O]
                        output = O_quantizer(output)
4468

4469
4470
4471
4472
4473
4474
4475
4476
4477
4478
4479
4480
4481
4482
4483
4484
4485
4486
4487
4488
4489
        if inference_params is None:
            if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
                output = dpa_utils.UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
        elif qkv_format in ["bshd", "sbhd_2bshd"]:
            # all KV caching cases use thd_2bshd for calculation
            # convert results back to bshd from thd_2bshd
            if isinstance(query_layer, Float8Tensor):
                output._data = tex.convert_thd_to_bshd(
                    output._data,
                    cu_seqlens_q,
                    batch_size,
                    context_len,
                )
                output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape)
            else:
                output = tex.convert_thd_to_bshd(
                    output,
                    cu_seqlens_q,
                    batch_size,
                    context_len,
                )
4490

4491
        if q_format == "sbhd":
4492
4493
4494
4495
4496
4497
4498
4499
4500
4501
4502
4503
4504
4505
            # (bs)hd -> bs(hd) -> sb(hd)
            if fp8 and fp8_meta["recipe"].fp8_mha:
                output_data = (
                    output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
                    .transpose(0, 1)
                    .contiguous()
                )
                output = Float8Tensor.make_like(
                    output,
                    data=output_data,
                    shape=output_data.shape,
                )
            else:
                output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
4506
        elif q_format == "bshd":
4507
4508
            # (bs)hd -> bs(hd)
            output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
4509
        elif q_format == "thd":
4510
4511
4512
4513
4514
4515
4516
4517
4518
4519
4520
4521
4522
4523
4524
4525
4526
4527
4528
4529
4530
4531
4532
4533
4534
4535
4536
4537
4538
4539
4540
4541
            # thd -> t(hd)
            output = output.reshape(output.shape[0], -1)

        return output.contiguous()


def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension without copying.

    Builds one strided view that stacks ``tensors`` along a new dimension
    inserted at ``dim``. The view aliases the storage of ``tensors[0]``
    (starting at its storage offset); ``tensors[1:]`` are only counted, never
    read. The stride of the new dimension is ``stride[dim - 1] // len(tensors)``,
    so this presumably assumes the inputs are equally spaced slices of one
    packed buffer (e.g. q/k/v split out of a packed qkv tensor) -- callers must
    guarantee that.

    For ``Float8Tensor`` inputs the view is built on the raw ``_data`` tensor and
    re-wrapped with ``Float8Tensor.make_like(tensors[0], ...)``.

    Args:
        tensors: slices of a single packed buffer, all with the same shape.
        dim: index of the new (packed) dimension in the result. Note that
            ``dim == 0`` would read ``stride[-1]``; the callers visible in this
            module pass the position of "3"/"2" inside the layout string.

    Returns:
        A tensor of shape ``tensors[0].shape`` with ``len(tensors)`` inserted at
        position ``dim``, sharing storage with ``tensors[0]``.
    """
    first = tensors[0]
    num_tensors = len(tensors)
    is_fp8 = isinstance(first, Float8Tensor)
    # FP8 tensors: build the view over the underlying raw data tensor.
    raw = first._data if is_fp8 else first

    new_shape = list(first.shape)
    new_shape.insert(dim, num_tensors)
    new_stride = list(raw.stride())
    # Integer floor division keeps stride arithmetic exact (no float round-trip).
    new_stride.insert(dim, new_stride[dim - 1] // num_tensors)

    combined_tensor = torch.Tensor().to(device=first.device, dtype=raw.dtype)
    combined_tensor.set_(raw.untyped_storage(), raw.storage_offset(), new_shape, new_stride)
    if is_fp8:
        combined_tensor = Float8Tensor.make_like(first, data=combined_tensor, shape=new_shape)
    return combined_tensor

4546

4547
4548
4549
4550
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
4551
4552
4553
4554
4555
4556
4557
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
4558
4559
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
4560
4561
        page_table_k,
        page_table_v,
4562
4563
4564
4565
4566
4567
4568
4569
4570
4571
        q,
        k,
        v,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
4572
        window_size,
4573
4574
4575
4576
4577
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
4578
        quantizers,
4579
        deterministic,
4580
    ):
4581
        # pylint: disable=missing-function-docstring
4582
        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
4583
        is_input_fp8 = False
4584
        is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
4585
4586
4587
4588

        # FP16/BF16 attn:                  fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = True:  fake_dtype = torch.float8_e4m3fn
4589
4590
4591
        fake_dtype = q.dtype

        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
4592
            dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
4593
        )
4594
4595
        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
4596
4597
4598
            assert isinstance(k, q.__class__) and isinstance(
                v, q.__class__
            ), "q, k, and v must have the same type."
4599

4600
            is_input_fp8 = isinstance(q, Float8Tensor)
4601
            q_fp8, k_fp8, v_fp8 = None, None, None
4602
            if is_input_fp8:
4603
                q_fp8, k_fp8, v_fp8 = q, k, v
4604
4605
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
4606
                qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
4607
4608
4609
4610
4611
4612
4613
4614
4615
4616
4617
4618
4619
4620
4621
4622
4623
4624
4625
4626
                match qkv_group:
                    case 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                        qkv_fp8 = QKV_quantizer(qkv)
                        q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True)
                    case 2:
                        q_fp8 = QKV_quantizer(q)
                        dim = qkv_layout.split("_")[1].find("2")
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                        kv_fp8 = QKV_quantizer(kv_c)
                        k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1], True)
                    case 3:
                        q_fp8 = QKV_quantizer(q)
                        k_fp8 = QKV_quantizer(k)
                        v_fp8 = QKV_quantizer(v)
                    case _:
                        raise "Invalid qkv_layout " + qkv_layout
4627
            # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
4628
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
4629
4630
4631
4632
4633
4634
4635
4636
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
4637
                fake_dtype,
4638
4639
                fused_attention_backend,
                attn_bias,
4640
4641
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
4642
4643
                None,
                None,
4644
4645
                S_quantizer,
                O_quantizer,
4646
4647
4648
4649
4650
4651
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
4652
                window_size,
4653
4654
                rng_gen,
            )
4655
            if is_output_fp8:
4656
                out_ret = out_fp8
4657
            else:
4658
                out_ret = out_fp8.dequantize().view(out_fp8.shape)
4659
4660
            # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16
            # is_output_fp8 = True:  out_save.dtype = torch.float8_e4m3fn
4661
4662
            out_save = out_ret

4663
            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
4664
                # 1: qkv packed, 2: kv packed, 3: qkv separate
4665
                if is_input_fp8:
4666
                    qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
4667
4668
4669
4670
                    if qkv_group == 1:
                        dim = qkv_layout.find("3")
                        qkv = _combine_tensors([q, k, v], dim)
                        qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
4671
4672
                        qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape)
                        q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
4673
                    if qkv_group == 2:
4674
                        q = q.dequantize()
4675
                        dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2")
4676
4677
                        kv = _combine_tensors([k, v], dim)
                        kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
4678
4679
                        kv_no_fp8 = kv.dequantize()
                        k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True)
4680
                    if qkv_group == 3:
4681
4682
4683
                        q = q.dequantize()
                        k = k.dequantize()
                        v = v.dequantize()
4684
                if is_output_fp8:
4685
4686
4687
                    out_save = out_fp8.dequantize()

            fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
4688
        else:
4689
            # q, k, v, out_ret: torch.float16 or torch.bfloat16
4690
            out_ret, aux_ctx_tensors = fused_attn_fwd(
4691
4692
4693
4694
4695
4696
4697
4698
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
4699
                fake_dtype,
4700
4701
                fused_attention_backend,
                attn_bias,
4702
4703
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
4704
4705
                page_table_k,
                page_table_v,
4706
4707
                None,  # s_quantizer
                None,  # o_quantizer
4708
4709
4710
4711
4712
4713
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
4714
                window_size,
4715
4716
                rng_gen,
            )
4717
            out_save = out_ret
4718
            fp8_tensors = (None, None, None, None)
4719

4720
4721
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

4722
        from .cpu_offload import CPUOffloadEnabled
4723

4724
        if CPUOffloadEnabled:
4725
4726
4727
4728
4729
4730
4731
            if ctx.fp8:
                tensor_list = fp8_tensors
            else:
                tensor_list = [q, k, v, out_save]

            tensor_list.extend(aux_ctx_tensors)

4732
            qkv_layout = "sbhd_sbhd_sbhd"
4733
4734
4735
4736
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

4737
4738
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
4739
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
4740
4741
        tensors_to_save, tensor_objects = prepare_for_saving(
            *fp8_tensors,
4742
4743
4744
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
4745
4746
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
4747
4748
            *aux_ctx_tensors,
        )
4749
4750
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
4751
        ctx.fp8_meta = fp8_meta
4752
4753
4754
4755
4756

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = S_quantizer
4757
4758
4759
        if ctx.fp8:
            ctx.S_quantizer = S_quantizer.copy()
            ctx.S_quantizer.scale = S_quantizer.scale.clone()
4760

4761
4762
4763
4764
4765
4766
4767
4768
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
4769
        ctx.window_size = window_size
4770
        ctx.fused_attention_backend = (
4771
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
4772
        )
4773
        ctx.use_FAv2_bwd = use_FAv2_bwd
4774
        ctx.deterministic = deterministic
4775

4776
        return out_ret
4777
4778
4779

    @staticmethod
    def backward(ctx, d_out):
4780
        # pylint: disable=missing-function-docstring
4781
        if ctx.is_output_fp8:
4782
4783
4784
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
4785

4786
4787
4788
4789
4790
        # FP16/BF16 attn:                  fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
        # FP8 attn, is_output_fp8 = True:  fake_dtype = torch.float8_e5m2
        fake_dtype = d_out.dtype

4791
        d_out = d_out.contiguous()
4792
        (
4793
4794
4795
4796
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
4797
4798
4799
4800
4801
4802
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
4803
4804
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
4805
4806
4807
4808
4809
            *other_tensors,
        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

        aux_ctx_tensors = other_tensors

4810
4811
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
4812
        rest = [None]
4813
        if ctx.use_FAv2_bwd:
4814
            softmax_lse, rng_state = aux_ctx_tensors
4815
4816
4817
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
4818
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
4819
            flash_attn_cuda_bwd(
4820
4821
4822
4823
4824
4825
4826
4827
4828
4829
4830
4831
4832
4833
4834
4835
4836
4837
4838
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
4839
            )
4840
4841
4842
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
4843
        else:
4844
4845
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
4846
                    if ctx.is_output_fp8:
4847
4848
                        d_out_fp8 = d_out
                    else:
4849
                        d_out_fp8 = ctx.dO_quantizer(d_out)
4850
4851
4852
                    dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
                    # q_fp8, k_fp8, v_fp8, out_fp8:      torch.float8_e4m3fn
                    # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2
4853
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
4854
4855
4856
4857
4858
4859
4860
4861
4862
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
4863
4864
                        fake_dtype,
                        dqkv_dtype,
4865
                        aux_ctx_tensors,
4866
                        ctx.fused_attention_backend,
4867
4868
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
4869
4870
4871
                        ctx.S_quantizer,
                        ctx.dP_quantizer,
                        ctx.dQKV_quantizer,
4872
4873
4874
4875
4876
4877
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
4878
4879
                        ctx.window_size,
                        ctx.deterministic,
4880
                    )
4881

4882
4883
                    # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
                    # is_input_fp8 = True:  dq, dk, dv: torch.float8_e5m2
4884
                    if not ctx.is_input_fp8:
4885
                        qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_"))
4886
                        if qkv_group == 1:
4887
                            dim = ctx.qkv_layout.find("3")
4888
4889
                            dqkv_fp8_data = _combine_tensors(
                                [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim
4890
                            )
4891
4892
4893
4894
4895
                            dqkv_fp8 = dq_fp8.make_like(
                                tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape
                            )
                            dqkv = dqkv_fp8.dequantize()
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True)
4896
                        if qkv_group == 2:
4897
                            dq = dq_fp8.dequantize()
4898
4899
4900
4901
4902
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
4903
4904
                            dkv = dkv_c_fp8.dequantize()
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1], True)
4905
                        if qkv_group == 3:
4906
4907
4908
4909
4910
                            dq = dq_fp8.dequantize()
                            dk = dk_fp8.dequantize()
                            dv = dv_fp8.dequantize()
                    else:
                        dq, dk, dv = dq_fp8, dk_fp8, dv_fp8
4911
                else:
4912
4913
                    if isinstance(d_out, QuantizedTensor):
                        d_out = d_out.dequantize()
4914
4915
                    dqkv_dtype = TE_DType[d_out.dtype]
                    # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16
4916
                    dq, dk, dv, *rest = fused_attn_bwd(
4917
4918
4919
4920
4921
4922
4923
4924
4925
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
4926
4927
                        fake_dtype,
                        dqkv_dtype,
4928
                        aux_ctx_tensors,
4929
                        ctx.fused_attention_backend,
4930
4931
                        cu_seqlens_q_padded,
                        cu_seqlens_kv_padded,
4932
4933
4934
4935
4936
4937
4938
4939
4940
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
4941
4942
                        ctx.window_size,
                        ctx.deterministic,
4943
                    )
4944

4945
4946
        # if no_bias or alibi, return dqkv
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
4947
4948
4949
4950
4951
4952
4953
4954
            return (
                None,
                None,
                None,
                None,
                None,
                None,
                None,
4955
4956
                None,
                None,
4957
4958
4959
4960
4961
4962
4963
4964
4965
4966
4967
4968
4969
4970
4971
4972
4973
4974
4975
                dq,
                dk,
                dv,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )
4976
        # else, return (dqkv, dbias)
4977
4978
4979
4980
4981
4982
4983
4984
        return (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
4985
4986
            None,
            None,
4987
4988
4989
4990
4991
4992
4993
4994
4995
4996
4997
4998
4999
5000
5001
5002
5003
            dq,
            dk,
            dv,
            rest[0],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
5004
            None,
5005
        )
5006

5007

5008
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    When ``forward`` is called with ``fp8=True``, the caller must select the FP8
    sub-backend (``NVTE_FP8``) instead.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # Opt-in via environment variable; only honored on devices with compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading Transformer Engine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in Transformer Engine 2.0.
            """
            # Iterate over snapshots: calling .remove() on a list while iterating it
            # skips the element right after each removal, so back-to-back matching
            # keys would otherwise be left behind.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
        quantizers=None,
        pad_between_seqs: bool = False,
        inference_params: Optional[InferenceParams] = None,
    ) -> torch.Tensor:
        """fused attention fprop

        Validates the inputs, resolves sequence-length metadata, and dispatches either to
        the context-parallel path (``attn_forward_func_with_cp``) or to ``FusedAttnFunc``.

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
            CUDA tensors of dtype FP16/BF16, or ``Float8Tensor``.
        qkv_layout : str
            Must be one of ``QKVLayouts``; determines ``qkv_format``/``q_format``/``kv_format``.
        cu_seqlens_q, cu_seqlens_kv : Optional[torch.Tensor]
            Cumulative sequence lengths. For ``sbhd``/``bshd`` without ``inference_params``
            they may be ``None``. They are then derived from ``attention_mask`` for padding
            masks, or built as full-length sequences otherwise. Required for ``thd`` when
            ``inference_params`` is ``None``.
        cu_seqlens_q_padded, cu_seqlens_kv_padded : Optional[torch.Tensor]
            Default to ``cu_seqlens_q``/``cu_seqlens_kv`` for ``thd`` formats and padding masks.
        max_seqlen_q, max_seqlen_kv : Optional[int]
            Recomputed from the tensor shapes (multiplied by the CP size) for ``sbhd``/``bshd``
            without ``inference_params``. Required for ``thd`` when ``inference_params`` is
            ``None``.
        fused_attention_backend : tex.NVTE_Fused_Attn_Backend
            Must not be ``NVTE_No_Backend``. Must be ``NVTE_FP8`` when ``fp8=True``.
        cp_group : Union[ProcessGroup, List[ProcessGroup]], optional
            Context-parallel group(s). The CP size is the product of their world sizes, and
            a CP size > 1 routes through ``attn_forward_func_with_cp``.
        inference_params : Optional[InferenceParams]
            When given and paged, its cache manager's page table is passed to ``FusedAttnFunc``.

        Returns
        -------
        torch.Tensor
            Attention output with its last two dimensions flattened into one.

        Raises
        ------
        AssertionError
            For unsupported dtype/device/layout/backend combinations.
        RuntimeError
            When a padding mask is requested but neither ``attention_mask`` nor
            ``cu_seqlens_q``/``cu_seqlens_kv`` are provided.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        assert all(
            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
            for x in [query_layer, key_layer, value_layer]
        ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # CP size is the product of the world sizes of every CP group (a single group,
        # or a list of groups for hierarchical CP).
        cp_size = 1
        if isinstance(cp_group, dist_group_type):
            cp_size = get_distributed_world_size(cp_group)
        elif isinstance(cp_group, list):
            for group in cp_group:
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1

        # get q_format and kv_format for training and inference
        qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params)

        # cuDNN can work with 0-length sequences in the batch for both bshd/sbhd and thd formats
        # however, for bshd/sbhd, q/k/v tensors need to have the same batch size as indicated by
        # cu_seqlens, whereas thd does not have this requirement
        # e.g. if q_format = bshd, and q.shape = [3, 1, 16, 64], we should have k.shape[0] =
        # v.shape[0] = q.shape[0], and cu_seqlens_q.shape = cu_seqlens_kv.shape = [4]
        if q_format in ["bshd", "sbhd"] or kv_format in ["bshd", "sbhd"]:
            # NOTE(review): when q_format is "thd" this reads query_layer.shape[1]; confirm
            # that is the intended batch size for mixed q/kv formats.
            batch_size = query_layer.shape[0] if q_format == "bshd" else query_layer.shape[1]
            # cu_seqlens may legitimately be None at this point; they are derived below.
            if cu_seqlens_q is not None:
                cu_seqlens_q = cu_seqlens_q[: batch_size + 1]
            if cu_seqlens_kv is not None:
                cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1]

        page_table = None
        if inference_params is None:
            if qkv_format in ["sbhd", "bshd"]:
                if qkv_format == "sbhd":
                    batch_size = query_layer.shape[1]
                    max_seqlen_q = query_layer.shape[0]
                    max_seqlen_kv = key_layer.shape[0]
                if qkv_format == "bshd":
                    batch_size = query_layer.shape[0]
                    max_seqlen_q = query_layer.shape[1]
                    max_seqlen_kv = key_layer.shape[1]
                # each CP rank holds a shard of the sequence; scale back to the full length
                max_seqlen_q *= cp_size
                max_seqlen_kv *= cp_size
                if "padding" in attn_mask_type:
                    assert (
                        not context_parallel
                    ), "Padding mask not supported with context parallelism!"
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        if attention_mask is None:
                            raise RuntimeError(
                                "Please provide attention_mask or cu_seqlens for padding!"
                            )
                        if self.attention_type == "self":
                            cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
                            cu_seqlens_kv = cu_seqlens_q
                        else:
                            cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
                            cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
                else:
                    if cu_seqlens_q is None:
                        cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
                    if cu_seqlens_kv is None:
                        cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
            if qkv_format == "thd":
                assert (
                    max_seqlen_q is not None
                    and max_seqlen_kv is not None
                    and cu_seqlens_q is not None
                    and cu_seqlens_kv is not None
                ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
        elif inference_params.is_paged:
            page_table = inference_params.cache_manager.page_table

        if (q_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_q_padded is None:
            cu_seqlens_q_padded = cu_seqlens_q
        if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None:
            cu_seqlens_kv_padded = cu_seqlens_kv

        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if fp8:
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
                "Amax reduction across TP+CP group is necessary when using context parallelism with"
                " FP8!"
            )

        if context_parallel:
            assert (
                fp8
                or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    cp_comm_type,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    deterministic=self.deterministic,
                    use_fused_attention=True,
                    window_size=window_size,
                    fp8=fp8,
                    fp8_meta=fp8_meta,
                    quantizers=quantizers,
                    pad_between_seqs=pad_between_seqs,
                )
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    cu_seqlens_q_padded,
                    cu_seqlens_kv_padded,
                    # the same table is passed twice (presumably separate K and V slots)
                    page_table,
                    page_table,
                    query_layer,
                    key_layer,
                    value_layer,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                    quantizers,
                    self.deterministic,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
5292
5293


5294
class DotProductAttention(TransformerEngineBaseModule):
5295
5296
5297
5298
5299
5300
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

5301
        Argument :attr:`attention_mask` in the `forward` call is only used when
5302
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
5303
5304
5305

    .. warning::

5306
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
5307
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
5308
5309
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
5310

5311
5312
5313
5314
5315
5316
5317
    .. note::

        Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
        As the FP8 attention support expands from one backend to multiple backends, the location
        of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).


5318
5319
5320
5321
    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
5322
5323
5324
    kv_channels : Union[int, Tuple[int, int]]
                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
5325
5326
5327
5328
5329
5330
5331
5332
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
5333
5334
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
5335
    attn_mask_type: str, default = `causal`
5336
                   type of attention mask passed into softmax operation, options are "`no_mask`",
5337
5338
5339
5340
5341
5342
5343
5344
5345
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
                   "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
                   "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
                   are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
                   `forward` method. It is useful for cases involving compilation/tracing, e.g.
                   ONNX export, and the forward arg is useful for dynamically changing mask types,
                   e.g. a different mask for training and inference.
                   1. For "`no_mask`", no attention mask is applied.
                   2. For "`causal`", "`causal_bottom_right`", or the causal mask in
5346
                   "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
5347
5348
5349
5350
5351
5352
5353
5354
5355
5356
5357
5358
5359
5360
                   calculates and applies an upper triangular mask to the softmax input.
                   No user input is needed. Causal masks without the "`bottom_right`" appendix align
                   the diagonal line to the top left corner of the softmax matrix. With
                   "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
                   often used in inference/KV caching.
                   3. For "`padding`", or the padding mask in "`padding_causal`" and
                   "`padding_causal_bottom_right`", users need to provide the locations of padded
                   tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
                   [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
                   in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
                   cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
                   [batch_size, 1, 1, max_seqlen_kv]).
                   4. For "`arbitrary`", users need to provide a mask that is broadcastable to
                   the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
5361
5362
5363
5364
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
5365
5366
5367
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
5368
                be overridden by :attr:`window_size` in `forward` as well.
5369
5370
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
5371
5372
5373
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
5374
5375
5376
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
5377
               `h` the number of heads, `d` head size, and `t` the total number of tokens
5378
5379
5380
5381
5382
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
5383
               For that, please use `get_qkv_layout` to gain the layout information.
5384
5385
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
5386
                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
5387
5388
5389
5390
5391
5392
5393
5394
5395

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
5396
    cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None`
5397
              context parallel process group.
5398
5399
5400
              ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
              List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
              and cp_group[1] are for a2a and p2p communications respectively.
5401
5402
5403
5404
5405
5406
5407
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
5408
    cp_comm_type : str, default = `p2p`
5409
                  inter-gpu communication type for context parallelism.
5410
                  Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
5411
5412
5413
5414
5415
5416
                  "p2p": Exchange KV chunks with P2P communications in ring topology.
                         P2P is async and can be overlapped with attention compute.
                  "all_gather": All-gather to get full sequence of KV before attention.
                                The all-gather is not async, and cannot be overlapped.
                  "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                         group, and gather to get full sequence of QKV.
5417
5418
5419
                  "a2a+p2p": hierarchical CP implementation. First apply a2a to QKV
                  across each CP sub-group (e.g., via NVLink), then exchanging KV with
                  p2p between sub-groups (e.g., via IBLink).
5420
5421
5422
5423
5424
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Create the flash, fused and unfused attention backends sharing one configuration.

        Arguments are documented on the class. Note that for cuDNN versions in
        [8.9.5, 9.0.0) this constructor may set the ``NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT``
        and ``CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT`` environment variables.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.logger.setLevel(attn_log._log_level)
        if not self.logger.hasHandlers():
            self.logger.addHandler(attn_log._stream_handler)
        self.qkv_format = qkv_format
        # normalize "padding,causal" / "causal,padding" spellings to "padding_causal"
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type

        # head sizes of K and V; a single int means both are the same
        self.hidden_size_per_attention_head_k = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
        )
        self.hidden_size_per_attention_head_v = (
            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
        )

        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(
                kv_channels if isinstance(kv_channels, int) else kv_channels[0]
            )

        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )
        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = get_cudnn_version()
        if (8, 9, 5) <= cudnn_version < (9, 0, 0):
            if self.deterministic:
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        self.flash_attention = FlashAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.fused_attention = FusedAttention(
            softmax_scale,
            attention_type=attention_type,
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
        )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale,
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older Transformer Engine checkpoints. Will phase out
            this hook in Transformer Engine 2.0.
            """
            # Iterate over a snapshot: removing from the list being iterated skips the
            # element right after each removal, leaving consecutive matches behind.
            for key in list(incompatible_keys.missing_keys):
                if "core_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

5568
5569
5570
5571
5572
5573
5574
5575
5576
5577
5578
5579
5580
5581
5582
5583
5584
5585
5586
5587
5588
5589
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load state while staying compatible with Transformer Engine 1.6 and 1.7 checkpoints.

        Those releases stored FP8 attention metadata under
        `core_attention.fused_attention._extra_state` instead of `core_attention._extra_state`.
        When only the legacy key is present, the lookup prefix is redirected into
        `fused_attention.` so the metadata is still picked up. Please see `FP8 checkpoint
        compatibility
        <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
        """
        legacy_key_present = any(
            "core_attention.fused_attention._extra_state" in name for name in state_dict.keys()
        )
        current_key_present = any(
            "core_attention._extra_state" in name for name in state_dict.keys()
        )
        redirected_prefix = prefix
        if legacy_key_present and not current_key_present:
            redirected_prefix = prefix + "fused_attention."
        super()._load_from_state_dict(
            state_dict,
            redirected_prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

5590
5591
5592
5593
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Run `attention_func` through `checkpoint` (activation checkpointing).

        This is the path used for `checkpoint_core_attention`, whose activations are meant to be
        recomputed during the backward pass instead of being stored. The call is made with
        `distribute_saved_activations=False`, this module's `get_rng_state_tracker` and its
        `tp_group`.

        Parameters
        ----------
        attention_func : Callable
                        Attention callable to run under checkpointing.
        *forward_args, **forward_kwargs
                        Passed through unchanged to `attention_func`.

        Returns
        -------
        torch.Tensor
            Whatever `checkpoint` returns for the wrapped call.
        """

        def _run_attention(*call_args, **call_kwargs):
            # NOTE(review): a plain-function wrapper, not `attention_func` itself, is handed to
            # `checkpoint`; it may treat nn.Module callables differently from plain functions,
            # so confirm before passing `attention_func` directly.
            return attention_func(*call_args, **call_kwargs)

        return checkpoint(
            _run_attention,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

5612
5613
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Configure context parallelism (CP) for this module; call before the forward pass.

        The arguments are stored on the module as given; no validation happens here.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  Context parallel process group(s). A single ProcessGroup is used with
                  cp_comm_type "p2p", "all_gather" and "a2a". A list is used with "a2a+p2p",
                  where cp_group[0] carries the a2a communication and cp_group[1] the p2p
                  communication.
        cp_global_ranks : List[int]
                         Global ranks of the members of the context parallel group.
        cp_stream : torch.cuda.Stream
                   CUDA stream used for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      Inter-GPU communication scheme for context parallelism; one of
                      "p2p", "all_gather", "a2a" or "a2a+p2p".

                      "p2p": KV chunks are exchanged with P2P communications in a ring
                             topology. P2P is asynchronous and can overlap with attention
                             compute.
                      "all_gather": KV is all-gathered to the full sequence before attention.
                                    The all-gather is not asynchronous and cannot be
                                    overlapped.
                      "a2a": DeepSpeed Ulysses style; attention heads are scattered across
                             the CP group and QKV are gathered along the full sequence.
                      "a2a+p2p": hierarchical CP. a2a is first applied to QKV within each CP
                                 sub-group (e.g. over NVLink), then KV is exchanged with p2p
                                 between sub-groups (e.g. over InfiniBand).
        """
        (
            self.cp_group,
            self.cp_global_ranks,
            self.cp_stream,
            self.cp_comm_type,
        ) = (cp_group, cp_global_ranks, cp_stream, cp_comm_type)
5651

5652
    @no_torch_dynamo(recursive=False)
5653
5654
5655
5656
5657
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
5658
5659
5660
5661
5662
5663
5664
5665
        attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
        qkv_format: str = None,
        cu_seqlens_q: torch.Tensor = None,
        cu_seqlens_kv: torch.Tensor = None,
        cu_seqlens_q_padded: torch.Tensor = None,
        cu_seqlens_kv_padded: torch.Tensor = None,
        max_seqlen_q: int = None,
        max_seqlen_kv: int = None,
5666
        attn_mask_type: Optional[str] = None,
5667
        window_size: Optional[Tuple[int, int]] = None,
5668
        checkpoint_core_attention: bool = False,
5669
5670
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
5671
        alibi_slopes: Optional[torch.Tensor] = None,
5672
        fast_zero_fill: bool = True,
5673
        inference_params: Optional[InferenceParams] = None,
5674
        pad_between_seqs: Optional[bool] = None,
5675
5676
5677
5678
5679
5680
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

5681
5682
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes '"padding"' or `"arbitrary"`.
5683

5684
5685
        .. note::

5686
5687
5688
5689
5690
5691
5692
5693
5694
5695
5696
5697
5698
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://arxiv.org/pdf/2305.13245.pdf>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
5699
            and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
5700
5701
5702
5703
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
5704
5705
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
5706
            optimizations in FusedAttention. When unset, Transformer Engine determines the code path
5707
5708
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
5709

5710
5711
5712
5713
5714
5715
5716
5717
5718
5719
5720
5721
5722
5723
5724
5725
5726
5727
5728
5729
5730
5731
5732
5733
5734
5735
5736
5737
5738
5739
5740
5741
5742
5743
5744
5745
5746
5747
5748
5749
5750
5751
5752
5753
5754
5755
5756
5757
5758
5759
5760
5761
5762
5763
        .. note::
            .. _cu_seqlens note:

            When training data has variable sequence lengths, users have two options.

            1. Manipulate the data and pad all sequences to the same length. Use
               :attr:`qkv_format` = {"bshd", "sbhd"} and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
               (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
               the real sequence length information. For example, a batch of 3 sequences
               [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative
               sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

            2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
               :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
               Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
               as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed
               without any padding, and the sequence length tensors would be
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention.

               In certain use cases, a varying number of identifier tokens are inserted between
               sequences. These tokens do not participate in the attention calculation.
               :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
               in such cases to correctly identify the start and end of each sequence in a batch.
               For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have
               :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and
               :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13]
               for self-attention.

        .. note::
            .. _max_seqlen note:

            When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch.
            :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of
            :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will
            infer them as such.

            When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and
            :attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch.
            When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`.
            This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this
            overhead, users are recommended to obtain the maximum sequence lengths from the data loaders
            and pass them in.

            - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch,
              dynamic shapes need to be supported for tensor construction. FlashAttention and
              UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static
              to create graphs before performance heuristics analysis. To reduce the number of graphs created
              per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size,
              :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of
              :attr:`query_layer`, "t" dimension of :attr:`key_layer`}.

5764
5765
5766
5767
5768
5769
5770
5771
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
5772
5773
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
5774
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
5775
5776
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
5777
5778
5779
5780
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable
             to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
5781
5782
5783
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
5784
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
5785
                   with shape [batch_size + 1] and dtype torch.int32.
5786
                   See :ref:`note<cu_seqlens note>` for more details.
5787
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
5788
5789
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
5790
                   See :ref:`note<cu_seqlens note>` for more details.
5791
5792
5793
5794
5795
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for
                   `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_q_padded = cu_seqlens_q`.
5796
                   See :ref:`note<cu_seqlens note>` for more details.
5797
5798
5799
5800
5801
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   When there is no padding between sequences in a batch,
                   `cu_seqlens_kv_padded = cu_seqlens_kv`.
5802
                   See :ref:`note<cu_seqlens note>` for more details.
5803
5804
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
5805
                      See :ref:`note<max_seqlen note>` for more details.
5806
5807
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
5808
                       See :ref:`note<max_seqlen note>` for more details.
5809
5810
5811
5812
5813
5814
5815
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
                       'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
                       'arbitrary'}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
                       are equivalent. By default, causal masks are aligned to the top left corner
                       of the softmax matrix. When "`bottom_right`" is specified in the mask type,
                       causal masks are aligned to the bottom right corner.
5816
        window_size: Optional[Tuple[int, int]], default = `None`
5817
                    Sliding window size for local attention.
5818
5819
5820
5821
5822
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
5823
        core_attention_bias_type: str, default = `no_bias`
5824
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
5825
        core_attention_bias: Optional[torch.Tensor], default = `None`
5826
5827
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
5828
5829
5830
5831
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
5832
        fast_zero_fill: bool, default = `True`
5833
                    Whether to use the fast path to set output tensors to 0 or not.
5834
5835
5836
5837
5838
5839
5840
5841
5842
5843
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
5844
5845
5846
        pad_between_seqs: Optional[bool], default = `None`
            If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
            If true, there are padding tokens between individual sequences in a packed batch.
5847
        """
5848

5849
5850
5851
5852
5853
        with self.prepare_forward(
            query_layer,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:
5854
5855
5856
5857
5858
5859
5860
5861
5862
5863
            # checks for RNG
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."

            # checks for FP8
5864
5865
5866
5867
            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
5868
                        self.logger.warning(
5869
5870
5871
                            """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                            """fp8_meta["recipe"].fp8_mha=True"""
                        )
5872
5873
5874
5875
5876
5877
5878
5879
5880
5881
            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
5882

5883
            # checks for q/k/v shapes
5884
5885
5886
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
5887
5888
5889
            assert (
                query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
5890
5891
5892
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
5893
5894
            num_attention_heads = query_layer.shape[-2]
            num_gqa_groups = key_layer.shape[-2]
5895
            assert (
5896
5897
5898
5899
5900
5901
                query_layer.shape[-1] == key_layer.shape[-1]
            ), "Queries and keys must have the same head dimension!"
            head_dim_qk, head_dim_v = query_layer.shape[-1], value_layer.shape[-1]
            assert (
                head_dim_qk == self.hidden_size_per_attention_head_k
            ), f"Keys have head_dim = {head_dim_qk}, "
5902
5903
            "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
            assert (
5904
5905
                head_dim_v == self.hidden_size_per_attention_head_v
            ), f"Values have head_dim = {head_dim_v}, "
5906
            "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
5907
5908
5909
5910
            assert num_gqa_groups == self.num_gqa_groups_per_partition, (
                "Keys and values must have num_gqa_group ="
                f" {self.num_gqa_groups_per_partition} heads! Found {num_gqa_groups}."
            )
5911

5912
            # checks for attention mask
5913
5914
5915
5916
5917
5918
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
5919
            assert (
5920
5921
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
5922

5923
            # checks for sliding window
5924
5925
            if window_size is None:
                window_size = self.window_size
5926
            window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
5927

5928
5929
5930
            # checks for qkv_format
            if qkv_format is None:
                qkv_format = self.qkv_format
5931
5932
5933
5934
5935
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"
5936
5937
5938
5939
5940
5941
5942
5943
5944
5945
5946
5947
5948
            batch_size = None
            if qkv_format in ["sbhd", "bshd"]:
                assert all(
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when {qkv_format=}!"
                if qkv_format == "sbhd":
                    batch_size = query_layer.shape[1]
                    max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
                else:
                    batch_size = query_layer.shape[0]
                    max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
5949
            if qkv_format == "thd":
5950
                assert all(
5951
5952
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
5953
5954
5955
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
5956
5957
5958
5959
5960
5961
5962
5963
5964
5965
5966
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
5967
                batch_size = len(cu_seqlens_q) - 1
5968
                if max_seqlen_q is None:
5969
5970
5971
5972
                    if cu_seqlens_q_padded is not None:
                        seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
                    else:
                        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
5973
                    max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
5974
                if max_seqlen_kv is None:
5975
5976
5977
5978
                    if cu_seqlens_kv_padded is not None:
                        seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
                    else:
                        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
5979
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
5980

5981
5982
5983
5984
5985
5986
5987
5988
5989
5990
5991
5992
5993
5994
5995
5996
5997
5998
5999
6000
6001
6002
6003
6004
6005
6006
6007
6008
6009
6010
6011
6012
6013
6014
6015
6016
6017
6018
6019
6020
6021
6022
6023
6024
6025
6026
6027
6028
6029
6030
6031
6032
6033
6034
6035
6036
6037
6038
6039
6040
6041
6042
6043
6044
6045
6046
6047
6048
6049
6050
            # update KV cache and retrieve saved tokens from cache for inference
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"

                # convert top-left causal to bottom-right causal due to KV caching
                # users can still use the same attention mask for inference as for training
                assert "padding" in attn_mask_type, "KV caching requires padding mask!"
                if attn_mask_type == "padding_causal":
                    attn_mask_type = attn_mask_type + "_bottom_right"

                self.attention_type = "cross"
                self.flash_attention.attention_type = self.attention_type
                self.fused_attention.attention_type = self.attention_type
                self.unfused_attention.attention_type = self.attention_type

                query_layer, key_layer, value_layer = [
                    x.contiguous() if not x.is_contiguous() else x
                    for x in [query_layer, key_layer, value_layer]
                ]

                # get full K/V tensors from cache and adjust cu_seqlens, qkv_format based on the cache
                (
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_kv,
                    qkv_format,
                ) = inference_params.step(
                    self.layer_number,
                    key_layer,
                    value_layer,
                    qkv_format,
                )
                cu_seqlens_q_padded = None
                cu_seqlens_kv_padded = None

            # get qkv's memory layout
            if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
                (
                    qkv_layout,
                    query_layer._data,
                    key_layer._data,
                    value_layer._data,
                    q_format,
                    kv_format,
                ) = dpa_utils.get_qkv_layout(
                    query_layer._data,
                    key_layer._data,
                    value_layer._data,
                    qkv_format=qkv_format,
                    inference_params=inference_params,
                )
            else:
                (
                    qkv_layout,
                    query_layer,
                    key_layer,
                    value_layer,
                    q_format,
                    kv_format,
                ) = dpa_utils.get_qkv_layout(
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_format=qkv_format,
                    inference_params=inference_params,
                )

            # adjust max_seqlen and cu_seqlens for CP
6051
6052
6053
6054
6055
6056
            cp_size = 1
            if isinstance(self.cp_group, dist_group_type):
                cp_size = get_distributed_world_size(self.cp_group)
            elif isinstance(self.cp_group, list):
                for group in self.cp_group:
                    cp_size *= get_distributed_world_size(group)
6057
            context_parallel = cp_size > 1
6058
            if q_format in ["sbhd", "bshd"]:
6059
                max_seqlen_q *= cp_size
6060
                if cu_seqlens_q is None:
6061
6062
6063
6064
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
6065
                        if self.attention_type == "self":
6066
                            cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
6067
                        else:
6068
                            cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
6069
                    else:
6070
                        cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
6071
6072
6073
6074
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
6075
6076
6077
6078
6079
6080
6081
6082
6083
6084
6085
6086
            if kv_format in ["sbhd", "bshd"]:
                max_seqlen_kv *= cp_size
                if cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        if self.attention_type == "self":
                            cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask)
                        else:
                            cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
                    else:
6087
                        cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
6088
6089
6090
6091
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )
6092

6093
            # set ALiBi attributes
6094
6095
6096
6097
6098
6099
6100
6101
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
6102
            bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
6103
6104
6105
6106
6107
6108
6109
6110
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
6111
                    or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
6112
6113
6114
6115
6116
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

6117
            # detect bias shape
6118
6119
            core_attention_bias_shape = None
            if core_attention_bias is not None:
6120
                if (
6121
6122
                    core_attention_bias.shape[0] == batch_size
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
6123
                ):
6124
6125
6126
6127
6128
6129
6130
6131
6132
6133
6134
6135
6136
6137
6138
6139
6140
                    core_attention_bias_shape = "bhss"
                elif (
                    core_attention_bias.shape[0] == 1
                    and core_attention_bias.shape[1] == query_layer.shape[-2]
                ):
                    core_attention_bias_shape = "1hss"
                elif (
                    core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1
                ):
                    core_attention_bias_shape = "b1ss"
                elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1:
                    core_attention_bias_shape = "11ss"
                else:
                    assert (
                        False
                    ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"

6141
6142
6143
6144
6145
6146
6147
6148
6149
6150
6151
            if pad_between_seqs is None:
                if qkv_format == "thd":
                    pad_between_seqs = (
                        cu_seqlens_q_padded is not None
                        and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
                    ) or (
                        cu_seqlens_kv_padded is not None
                        and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
                    )
                else:
                    pad_between_seqs = False
6152

6153
            # gather attention params for get_attention_backend
6154
            attention_params = dpa_utils.AttentionParams(
6155
6156
6157
6158
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
6159
6160
                num_heads=num_attention_heads,
                num_gqa_groups=num_gqa_groups,
6161
6162
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
6163
6164
                head_dim_qk=head_dim_qk,
                head_dim_v=head_dim_v,
6165
6166
6167
6168
6169
6170
6171
6172
6173
6174
6175
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias_shape=core_attention_bias_shape,
                core_attention_bias_requires_grad=(
                    core_attention_bias.requires_grad if core_attention_bias is not None else False
                ),
                pad_between_seqs=pad_between_seqs,
                attention_dropout=self.attention_dropout,
                context_parallel=context_parallel,
6176
6177
                deterministic=self.deterministic,
                is_training=self.training,
6178
6179
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
6180
                inference_params=inference_params,
6181
            )
6182
            global _attention_backends
6183
6184
6185
6186
6187
6188
6189
6190
6191
            if (
                _attention_backends["attention_params"] is None
                or attention_params != _attention_backends["attention_params"]
            ):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
                (
                    use_flash_attention,
6192
                    flash_attention_backend,
6193
6194
6195
6196
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
                    _,
6197
6198
6199
6200
                ) = dpa_utils.get_attention_backend(attention_params)
                # Set global _attention_backends var using return value
                # from get_attention_backend()
                _attention_backends["use_flash_attention"] = use_flash_attention
6201
                _attention_backends["flash_attention_backend"] = flash_attention_backend
6202
6203
6204
6205
                _attention_backends["use_fused_attention"] = use_fused_attention
                _attention_backends["fused_attention_backend"] = fused_attention_backend
                _attention_backends["use_unfused_attention"] = use_unfused_attention
                _attention_backends["backend_selection_requires_update"] = False
6206
                if use_flash_attention:
6207
6208
                    self.logger.info(
                        "Running with FlashAttention backend (version %s)",
6209
                        flash_attention_backend,
6210
                    )
6211
6212
6213
6214
                elif use_fused_attention:
                    self.logger.info(
                        "Running with FusedAttention backend (sub-backend %s)",
                        int(fused_attention_backend),
6215
                    )
6216
6217
6218
6219
                elif use_unfused_attention:
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
6220
                flash_attention_backend = _attention_backends["flash_attention_backend"]
6221
6222
6223
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]
6224

6225
6226
            # raise exception if no backend is available
            if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
6227
6228
6229
6230
6231
                raise ValueError(
                    "No dot product attention backend is available for the provided inputs. Please"
                    " run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for"
                    " disabling all backends."
                )
6232
6233

            # run attention
6234
6235
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
6236
6237
                    alibi_slopes, _ = dpa_utils.get_alibi(
                        _alibi_cache,
6238
6239
6240
6241
6242
6243
6244
6245
6246
6247
6248
6249
6250
6251
6252
6253
6254
6255
6256
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
6257
                    cp_comm_type=self.cp_comm_type,
6258
6259
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
6260
6261
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
6262
                    quantizers=self.quantizers,
6263
6264
                    inference_params=inference_params,
                    flash_attention_backend=flash_attention_backend,
6265
                )
6266

6267
            if use_fused_attention:
6268
6269
                fu_core_attention_bias_type = core_attention_bias_type
                fu_core_attention_bias = core_attention_bias
6270
6271
6272
                if core_attention_bias_type == "alibi" and (
                    alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
                ):
6273
                    fu_core_attention_bias_type = "post_scale_bias"
6274
6275
                    _, fu_core_attention_bias = dpa_utils.get_alibi(
                        _alibi_cache,
6276
6277
6278
6279
6280
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                        bias_dtype=query_layer.dtype,
6281
                        bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
6282
                    )
6283
                # checkpoint_core_attention=False
6284
6285
6286
6287
6288
6289
6290
6291
6292
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
6293
6294
                        cu_seqlens_q_padded=cu_seqlens_q_padded,
                        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
6295
6296
6297
6298
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
6299
                        window_size=window_size,
6300
6301
6302
6303
6304
6305
6306
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
6307
                        cp_comm_type=self.cp_comm_type,
6308
6309
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
6310
                        quantizers=self.quantizers,
6311
                        pad_between_seqs=pad_between_seqs,
6312
                        inference_params=inference_params,
6313
6314
                    )
                return self.fused_attention(
6315
6316
6317
6318
6319
6320
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
6321
6322
                    cu_seqlens_q_padded=cu_seqlens_q_padded,
                    cu_seqlens_kv_padded=cu_seqlens_kv_padded,
6323
6324
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
6325
6326
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
6327
                    window_size=window_size,
6328
                    fused_attention_backend=fused_attention_backend,
6329
6330
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
6331
6332
6333
6334
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
6335
                    cp_comm_type=self.cp_comm_type,
6336
6337
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
6338
                    quantizers=self.quantizers,
6339
                    pad_between_seqs=pad_between_seqs,
6340
                    inference_params=inference_params,
6341
                )
6342

6343
            from .cpu_offload import CPUOffloadEnabled
6344

6345
6346
6347
6348
6349
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
6350

6351
6352
6353
6354
6355
6356
6357
6358
6359
6360
6361
6362
            if use_unfused_attention:
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
6363
                        window_size=window_size,
6364
6365
6366
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
6367
                        inference_params=inference_params,
6368
6369
                    )
                return self.unfused_attention(
6370
6371
6372
                    query_layer,
                    key_layer,
                    value_layer,
6373
6374
6375
6376
6377
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
6378
                    window_size=window_size,
6379
6380
6381
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
6382
                    inference_params=inference_params,
6383
                )
6384
            return None
6385
6386


6387
6388
6389
6390
6391
6392
6393
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

6394
6395
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
6396

6397
6398
6399
6400
6401
6402
6403
6404
6405
6406
6407
6408
6409
6410
6411
6412
6413
6414
6415
6416
6417
6418
6419
6420
6421
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing the QKV weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
6422
6423
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                   'padding_causal_bottom_right','arbitrary'},
6424
                   default = `causal`
6425
6426
6427
6428
6429
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
6430
6431
6432
6433
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
6434
6435
6436
                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
                `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
6437
                be overridden by :attr:`window_size` in `forward` as well.
6438
6439
6440
6441
6442
6443
6444
6445
6446
6447
6448
6449
6450
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
6451
6452
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
6453
6454
6455
6456
6457
6458
6459
6460
6461
6462
6463
6464
6465
6466
6467
6468
6469
6470
6471
6472
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
6473
          The device on which the parameters of the model will be allocated. It is the user's
6474
6475
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
6476
6477
6478
6479
6480
6481
6482
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
6483
            For that, please use `get_qkv_layout` to gain the layout information.
6484
6485
6486
6487
6488
6489
6490
6491
6492
6493
6494
6495
6496
6497
6498
6499
6500
6501
6502
6503
6504
6505
6506
6507
6508
6509
6510
6511
6512
6513
6514
6515
6516
6517
6518
6519
6520
6521
6522
6523

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, this module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
6524
6525
6526
6527
6528
6529
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
6530
6531
6532
6533
6534
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
6535
        layer_number: Optional[int] = None,
6536
        attn_mask_type: str = "causal",
6537
        window_size: Optional[Tuple[int, int]] = None,
6538
6539
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
6540
        num_gqa_groups: Optional[int] = None,
6541
6542
6543
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
6544
        params_dtype: Optional[torch.dtype] = None,
6545
        return_bias: bool = False,
6546
6547
6548
6549
6550
6551
6552
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
6553
        ub_overlap_ag: bool = False,
6554
6555
6556
6557
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
6558
        bias: bool = True,
6559
        normalization: str = "LayerNorm",
6560
        device: Union[torch.device, str] = "cuda",
6561
        qkv_format: str = "sbhd",
6562
6563
    ) -> None:
        super().__init__()
6564

6565
        self.qkv_format = qkv_format
6566
        self.attn_mask_type = attn_mask_type
6567
        self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
6568
        self.layer_number = 1 if layer_number is None else layer_number
6569
6570
6571
6572
6573
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
6574
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
6575
        self.num_attention_heads = num_attention_heads
6576
        self.return_bias = return_bias
6577
6578
        self.cp_size = 1
        self.cp_rank = 0
6579
6580
6581
6582
6583
6584
6585

        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()
6586
6587
6588
6589
6590

        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

6591
6592
6593
        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"
6594
6595
6596
6597
6598
6599

        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
6600
6601
6602
6603
6604
6605
6606
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
6607
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
6608
6609
6610
6611

        self.hidden_size_per_attention_head = kv_channels
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
6612
6613
6614
6615
6616
6617
6618

        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
6619
            "params_dtype": self.params_dtype,
6620
            "device": device,
6621
6622
6623
6624
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

cyanguwa's avatar
cyanguwa committed
6625
        if self.attention_type == "self":
6626
6627
            parameters_split = None
            if not fuse_qkv_params:
6628
6629
6630
6631
6632
6633
6634
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
6635
6636
6637
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
6638
                    self.hidden_size_q + 2 * self.hidden_size_kv,
6639
6640
6641
6642
6643
6644
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
cyanguwa's avatar
cyanguwa committed
6645
                    parameters_split=parameters_split,
6646
6647
6648
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
6649
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
6650
                    ub_overlap_ag=ub_overlap_ag,
6651
                    normalization=normalization,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6652
                    ub_name="qkv",
6653
6654
6655
6656
6657
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
6658
                    self.hidden_size_q + 2 * self.hidden_size_kv,
6659
6660
6661
6662
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
cyanguwa's avatar
cyanguwa committed
6663
                    parameters_split=parameters_split,
6664
6665
                    **common_gemm_kwargs,
                )
cyanguwa's avatar
cyanguwa committed
6666
        elif self.attention_type == "cross":
6667
6668
6669
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
6670
                    self.hidden_size_q,
6671
6672
6673
6674
6675
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
6676
                    parameters_split=("query",) if not fuse_qkv_params else None,
6677
6678
6679
6680
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
6681
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
6682
                    ub_overlap_ag=ub_overlap_ag,
6683
                    normalization=normalization,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6684
                    ub_name="qkv",
6685
6686
6687
6688
6689
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
6690
                    self.hidden_size_q,
6691
6692
6693
6694
6695
6696
6697
6698
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
6699
                2 * self.hidden_size_kv,
6700
6701
6702
6703
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
6704
                parameters_split=("key", "value") if not fuse_qkv_params else None,
6705
6706
6707
6708
6709
6710
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
6711
            self.hidden_size_per_attention_head,
6712
6713
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
6714
            qkv_format=self.qkv_format,
6715
6716
6717
6718
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
6719
            layer_number=self.layer_number,
6720
            attention_type=self.attention_type,
6721
6722
6723
6724
        )

        # Linear
        self.proj = Linear(
6725
            self.hidden_size_q,
6726
6727
6728
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
6729
            return_bias=return_bias,
6730
            parallel_mode="row" if set_parallel_mode else None,
6731
6732
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
Przemyslaw Tredak's avatar
Przemyslaw Tredak committed
6733
            ub_name="proj",
6734
6735
6736
6737
            **common_gemm_kwargs,
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Assign the tensor parallel process group used by this module.

        Intended to be called before the forward pass is executed.

        Parameters
        ----------
        tp_group : Optional[ProcessGroup]
                  tensor parallel process group to store on the module
                  (may be `None`).
        """
        # Only this module's own attribute is updated; submodules are not visited here.
        self.tp_group = tp_group

6749
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        Also forwards the same arguments to every submodule that defines
        `set_context_parallel_group`.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup], None]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
                  `None` disables context parallelism for this module
                  (cp_size is reset to 1 and cp_rank to 0).
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        if isinstance(cp_group, dist_group_type):
            self.cp_size = get_distributed_world_size(cp_group)
            self.cp_rank = get_distributed_rank(cp_group)
        elif isinstance(cp_group, list):
            assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
            assert (
                cp_comm_type == "a2a+p2p"
            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
            cp_size_a2a = get_distributed_world_size(cp_group[0])
            cp_rank_a2a = get_distributed_rank(cp_group[0])
            cp_size_p2p = get_distributed_world_size(cp_group[1])
            cp_rank_p2p = get_distributed_rank(cp_group[1])
            # The global CP rank enumerates a2a ranks fastest within each p2p rank.
            self.cp_size = cp_size_a2a * cp_size_p2p
            self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a
        elif cp_group is None:
            # No context parallelism. Reset explicitly so values left over from an
            # earlier call are not passed to apply_rotary_pos_emb in forward().
            self.cp_size = 1
            self.cp_rank = 0

        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
6805

6806
6807
6808
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        fast_zero_fill: bool = True,
        pad_between_seqs: Optional[bool] = None,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be `None` for causal masks and "`no_mask`". For padding masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means
             the corresponding position is masked out and a `False` means that position
             is allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                       'padding_causal_bottom_right','arbitrary'},
                       default = `None`
                       type of attention mask passed into softmax operation. By default,
                       causal masks are aligned to the top left corner of the softmax matrix.
                       When "`bottom_right`" is specified in the mask type, causal masks are
                       aligned to the bottom right corner. If `None`, the module's
                       `attn_mask_type` attribute is used.
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention. If `None`, the module's
                    `window_size` attribute is used.
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block, used as the key/value input when
             `attention_type="cross"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        inference_params: Optional[InferenceParams], default = `None`
                         Inference state for KV caching. Cache memory for this layer is
                         allocated on first use; the object is also used to offset rotary
                         embeddings and is passed through to core attention.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        pad_between_seqs: Optional[bool], default = `None`
            If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
            If true, there are padding tokens between individual sequences in a packed batch.

        Returns
        -------
        The attention output. If `return_bias` is set, the projection bias is appended;
        if `input_layernorm` and `return_layernorm_output` are set, the layernorm output
        is appended. A tuple is returned when there is more than one output, otherwise
        the single output tensor.
        """
        # hidden_states: [sq, b, h]

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        if window_size is None:
            window_size = self.window_size
        window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)

        if "padding" in attn_mask_type and attention_mask is not None:
            # Self-attention passes one mask tensor, cross-attention a (q, kv) pair.
            # Normalize so a single tensor is checked as a whole rather than iterated
            # over its batch dimension (which checks nothing for an empty batch).
            masks = (
                attention_mask if isinstance(attention_mask, (tuple, list)) else (attention_mask,)
            )
            for mask in masks:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"

        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
        # Pre-allocate memory for key-value cache for inference
        # =================================================

        if (
            inference_params is not None
            and self.layer_number not in inference_params.cache_manager.cache
        ):
            inference_params.allocate_memory(self.layer_number)

        # ======================
        # Query, Key, and Value
        # ======================

        fp8_mha = (
            FP8GlobalStateManager.is_fp8_enabled()
            and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
        )
        # Projections may emit FP8 directly only when no RoPE follows, because RoPE
        # rejects Float8Tensor inputs (see the assertion in the rotary section).
        fp8_output = fp8_mha and rotary_pos_emb is None

        layernorm_output = None
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )

            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
            if self.qkv_weight_interleaved:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along third last dimension
                split_dim = -3

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
            )

            if self.qkv_format == "thd":
                # query: -> [t, np, hn]
                # key, value: -> [t, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )
            else:
                # query: -> [sq, b, np, hn]
                # key, value: -> [sq, b, ng, hn]
                query_layer, key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (query_layer, key_layer, value_layer)
                )

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
                fp8_output=fp8_output,
            )

            if self.qkv_weight_interleaved:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    2 * self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            key_layer, value_layer = _SplitAlongDim.apply(
                mixed_kv_layer,
                split_dim,
                mixed_kv_layer.shape[split_dim] // 2,
            )

            # Mirror the self-attention branch: "thd" tensors have no separate batch
            # dimension, so the 4-D reshape would insert a spurious size-1 axis.
            if self.qkv_format == "thd":
                # key, value: -> [t, ng, hn]
                key_layer, value_layer = (
                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
                    for x in (key_layer, value_layer)
                )
            else:
                # key, value: -> [sk, b, ng, hn]
                key_layer, value_layer = (
                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                    for x in (key_layer, value_layer)
                )

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=fp8_output,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================

        if rotary_pos_emb is not None:
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
            # duplicate the pos_emb for self attention
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = (rotary_pos_emb,) * 2

            q_pos_emb, k_pos_emb = rotary_pos_emb

            # During inference, slice the embeddings to the positions of the current
            # step, starting at the sequence length reported before this step.
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)
                else:
                    raise ValueError(
                        f"qkv_format={self.qkv_format} not supported for KV caching and RoPE."
                    )

                sequence_start = inference_params.get_seqlens_pre_step()
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

            query_layer = apply_rotary_pos_emb(
                query_layer,
                q_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_q,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )
            key_layer = apply_rotary_pos_emb(
                key_layer,
                k_pos_emb,
                self.qkv_format,
                fused=True,
                cu_seqlens=cu_seqlens_kv,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
            )

        # ===========================
        # Core attention computation
        # ===========================

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            qkv_format=self.qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
            window_size=window_size,
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            fast_zero_fill=fast_zero_fill,
            inference_params=inference_params,
            pad_between_seqs=pad_between_seqs,
        )

        # ===================
        # Output. [sq, b, h]
        # ===================
        projection_output = self.proj(
            context_layer,
            is_first_microbatch=is_first_microbatch,
            fp8_grad=isinstance(context_layer, QuantizedTensor),
        )

        if self.return_bias:
            attention_output, attention_bias = projection_output
            outputs = (attention_output, attention_bias)
        else:
            outputs = (projection_output,)
        if self.input_layernorm and self.return_layernorm_output:
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]