attention.py 263 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
import math
10
import os
11
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
import warnings
13
import logging
14

cyanguwa's avatar
cyanguwa committed
15
import numpy as np
16
from packaging.version import Version as PkgVersion
17
18

import torch
19
import torch.nn.functional as F
20

21
import transformer_engine_torch as tex
22
23
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
24
25
26
27
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
28
29
30
31
32
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
33
34
    fused_attn_fwd,
    fused_attn_bwd,
35
36
37
38
39
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
40
41
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
42
from transformer_engine.pytorch.module import LayerNormLinear, Linear
43
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
44
45
46
47
48
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
49
    get_default_init_method,
50
51
52
53
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
54
    AttnBiasTypes,
55
    QKVLayouts,
56
    dist_group_type,
57
    TE_DType,
58
59
60
61
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
62
    get_distributed_rank,
63
    checkpoint,
64
65
66
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
67
68
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
69
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
70
71
from transformer_engine.pytorch.graph import is_graph_capturing

72

73
74
75
76
77
78
79
# Installed flash-attn version and the version window this module is written against.
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")


def _flash_attn_at_least(minimum: str) -> bool:
    """Return whether the installed flash-attn version is ``>= minimum``."""
    return _flash_attn_version >= PkgVersion(minimum)


# Feature gates consulted before passing optional flash-attn arguments.
_flash_attn_2_1_plus = _flash_attn_at_least("2.1")
_flash_attn_2_3_plus = _flash_attn_at_least("2.3")
_flash_attn_2_4_plus = _flash_attn_at_least("2.4")
_flash_attn_2_4_1_plus = _flash_attn_at_least("2.4.1")

# The flash-attn entry points are only imported when the minimum version is met;
# otherwise these module-level names stay undefined.
if _flash_attn_version >= _flash_attn_version_required:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd

# FP8 meta-tensor slots (tex.FP8FwdTensors / tex.FP8BwdTensors) reused for attention
# tensors; the pairing is by name only — TODO confirm against the FP8 recipe code.
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
# The effective level is non-zero only when debug mode is enabled; any value other
# than 0/1/2 falls back to the most verbose setting (DEBUG).
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
# NOTE(review): basicConfig configures the root logger at import time.
logging.basicConfig(
    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
    level=log_levels.get(log_level, log_levels[2]),
)

# Module-level state read and written by `get_alibi`; the two *_require_update flags
# tell it which entries to regenerate on its next call.
_alibi_cache = {
    **dict.fromkeys(
        (
            "_num_heads",
            "_alibi_slopes",
            "_max_seqlen_q",
            "_max_seqlen_kv",
            "_alibi_bias",
        ),
        None,
    ),
    "_alibi_slopes_require_update": False,
    "_alibi_bias_require_update": False,
}

__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]

118

119
class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Offsets start at 0 and are not advanced by this class — presumably the
        # caller updates them as decoding progresses (TODO confirm).
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # layer_number -> (key_memory, value_memory); `swap_key_value_dict` treats
        # dim 1 of each tensor as the batch dimension.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        AssertionError
            If `len(batch_indices)` differs from dim 1 of any cached key tensor.
            All layers are validated before any is modified, so a failure leaves
            the cache untouched.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict in empty")

        # Validate every layer up front so a mismatch cannot leave the cache half-swapped.
        for layer_number, (inference_key_memory, _) in self.key_value_memory_dict.items():
            assert len(batch_indices) == inference_key_memory.shape[1], (
                f"batch_indices has length {len(batch_indices)} but the KV cache of layer "
                f"{layer_number} has batch size {inference_key_memory.shape[1]}"
            )

        for layer_number, inference_memory in self.key_value_memory_dict.items():
            inference_key_memory, inference_value_memory = inference_memory
            # Replacing the value of an existing key is safe during iteration.
            self.key_value_memory_dict[layer_number] = (
                inference_key_memory[:, batch_indices],
                inference_value_memory[:, batch_indices],
            )
163

164

165
166
167
168
169
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return ALiBi slopes and the matching attention bias, cached in the module-level
    `_alibi_cache`.

    Slopes are regenerated only when `_alibi_cache["_alibi_slopes_require_update"]`
    is set, and the bias only when `_alibi_cache["_alibi_bias_require_update"]` is
    set; each flag is cleared after its update. When a flag is not set the cached
    value is returned unchanged and the corresponding arguments are ignored.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
        then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
        `alibi_slopes` is in [batch_size, num_heads], then the bias is in
        [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].

    Raises
    ------
    AssertionError
        If a bias update is requested while no slopes are cached.
    ValueError
        If the cached slopes are neither 1-D nor 2-D.
    """
    # `_alibi_cache` is only mutated (never rebound), so no `global` statement is needed.
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: for the largest power of two n <= num_heads, the geometric
            # sequence 2**(-8k/n), k = 1..n; any remaining heads take the odd-indexed
            # terms of the same sequence built for 2n heads.
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        slopes = _alibi_cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        if slopes.dim() == 1:
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            slopes_shape = torch.Size([*slopes.shape[:], 1, 1])
        else:
            raise ValueError(
                "ALiBi slopes must have shape [num_heads] or [batch_size, num_heads], "
                f"got shape {tuple(slopes.shape)}"
            )

        # bias[..., i, j] = -|(j - i) - (max_seqlen_kv - max_seqlen_q)|: zero where query i
        # lines up with key i + (max_seqlen_kv - max_seqlen_q), i.e. the last query with
        # the last key.
        bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
            1, 1, 1, max_seqlen_kv
        )
        bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(
            1, 1, max_seqlen_q, 1
        )
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
235
236
237
238
239
240
241
242
243


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    `True` entries in `mask` mark padded positions: each sample's length is the
    number of `False` entries in its row. The result is on `mask`'s device.
    """
    mask = mask.squeeze(1).squeeze(1)
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Allocate the leading zero next to the data (not on the default CUDA device),
    # otherwise torch.cat fails for CPU masks or masks on a non-current GPU.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    return cu_seqlens

251

252
253
254
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flattened positions of the valid tokens, padded at the end with
    the out-of-range value `batch_size * max_seqlen`.

    `True` entries in `mask` mark padded positions. Both outputs are on `mask`'s device.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Keep the leading zero on the mask's device so torch.cat never mixes devices.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    mask = mask.reshape(-1)
    # nonzero() returns int64 positions of shape [num_valid, 1] -> [num_valid, 1, 1].
    indices = mask.logical_not().nonzero()
    indices = indices.unsqueeze(-1)

    num_nonzeros = indices.shape[0]
    pad_amount = bs * seqlen - num_nonzeros
    # Pad dim 0 up to bs * seqlen with an out-of-range sentinel index.
    indices = F.pad(
        input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen)
    )

    return cu_seqlens, indices


280
281
282
283
284
285
286
287
def get_indices(
    max_seqlen: int,
    cu_seqlens: torch.Tensor,
    device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Sample `i` with length `L_i` contributes `i * max_seqlen + 0 .. L_i - 1`; the
    tail is padded with the out-of-range value `batch_size * max_seqlen`.

    Parameters
    ----------
    max_seqlen: int
        Padded per-sample sequence length.
    cu_seqlens: torch.Tensor
        Cumulative sequence lengths, shape [batch_size + 1].
    device: str or torch.device, default = "cuda"
        Device of the returned tensor.
    """
    bs = len(cu_seqlens) - 1
    # Stay in int64 throughout: the old float32 round-trip rounded indices above 2**24.
    # Clamping matches the old `range(j)`, which produced nothing for negative lengths.
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64).clamp(min=0)
    src_device = seqlens.device

    # First flattened position of each sample's row in the padded [bs, max_seqlen] layout.
    row_starts = torch.arange(bs, dtype=torch.int64, device=src_device) * max_seqlen
    # Exclusive prefix sum: where each sample's tokens begin in the packed output.
    packed_starts = seqlens.cumsum(dim=0) - seqlens
    num_nonzeros = int(seqlens.sum())
    # Offset of each valid token within its own sample: 0, 1, ..., L_i - 1.
    within = torch.arange(num_nonzeros, dtype=torch.int64, device=src_device)
    within = within - torch.repeat_interleave(packed_starts, seqlens)
    indices = torch.repeat_interleave(row_starts, seqlens) + within
    indices = indices.view(-1, 1, 1).to(device=device)

    pad_amount = bs * max_seqlen - num_nonzeros
    indices = F.pad(
        input=indices,
        pad=(0, 0, 0, 0, 0, pad_amount),
        mode="constant",
        value=float(bs * max_seqlen),
    )

    return indices

302

303
# (batch_size, max_seqlen, device) -> cached cu_seqlens tensor built by `_get_full_cu_seqlens`.
_cu_seqlens_cache = {}


def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length, so the result is
    `[0, max_seqlen, 2 * max_seqlen, ..., batch_size * max_seqlen]` (int32).
    Results are memoized per `(batch_size, max_seqlen, device)`; the same tensor
    object is returned on repeated calls, so callers must not modify it in place.
    """
    # `device` is part of the key: without it a call for a second device would
    # return the tensor cached for the first one.
    key = (batch_size, max_seqlen, device)
    if key not in _cu_seqlens_cache:
        _cu_seqlens_cache[key] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[key]
326
327


328
329
330
331
332
333
334
335
336
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Gather rows of `tensor` along dim 0 at the positions given by `indices`.

    A zero row is appended to `tensor` first, so an index equal to
    `tensor.shape[0]` selects zeros. `indices` is expected as [N, 1, 1] and is
    broadcast over dims 1 and 2 (assumed from the repeat below — TODO confirm).
    """
    d1 = tensor.shape[1]
    d2 = tensor.shape[2]
    zero_row = torch.zeros(1, d1, d2, dtype=tensor.dtype, device=tensor.device)
    padded = torch.cat((tensor, zero_row), dim=0)
    gather_index = indices.repeat(1, d1, d2)
    return torch.gather(padded, 0, gather_index)


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply `pack_tensor` with the same `indices` to two tensors and return both results.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply `pack_tensor` with the same `indices` to three tensors and return all results.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2), pack_tensor(indices, t3)


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`: scatter the rows of `tensor` into a zero tensor with
    `dim0` rows, at the positions given by `indices`.

    Index values equal to `dim0` land in an extra scratch row that is dropped
    before returning, so padding entries are discarded.
    """
    d1 = tensor.shape[1]
    d2 = tensor.shape[2]
    scatter_index = indices.repeat(1, d1, d2)
    # dim0 + 1 rows: the last one absorbs every write from padding indices.
    out = torch.zeros(dim0 + 1, d1, d2, dtype=tensor.dtype, device=tensor.device)
    out.scatter_(0, scatter_index, tensor)
    return out[:-1, :, :]


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`: apply `unpack_tensor` to both tensors.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`: apply `unpack_tensor` to all three tensors.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function that packs one to three tensors with shared `indices`.

    Forward dispatches to `pack_tensor` / `pack_2_tensors` / `pack_3_tensors`;
    backward scatters the incoming gradients back with the matching unpack helper.
    """

    @staticmethod
    def forward(
        ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        count = len(tensors)
        assert 1 <= count <= 3, f"Packing {count} tensors not supported."
        ctx.save_for_backward(indices)
        # Unpacked leading size, needed by backward to rebuild the full-size gradient.
        ctx.dim0 = tensors[0].shape[0]
        if count == 3:
            return pack_3_tensors(indices, *tensors)
        if count == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_tensor(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        (indices,) = ctx.saved_tensors
        count = len(grad_outputs)
        # The leading None is the (non-existent) gradient w.r.t. `indices`.
        if count == 1:
            return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
        unpack = unpack_2_tensors if count == 2 else unpack_3_tensors
        return (None, *unpack(indices, ctx.dim0, *grad_outputs))
452
453
454
455
456
457


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function that unpacks a tensor back to `dim0` rows via `unpack_tensor`;
    backward re-packs the gradient with `pack_tensor`.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        ctx.save_for_backward(indices)
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        indices = ctx.saved_tensors[0]
        grad_tensor = pack_tensor(indices, grad_output)
        # No gradients for `indices` or `dim0`.
        return None, None, grad_tensor
473
474


475
476
477
def flash_attn_p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """
    Point-to-point communications of KV and dKV in Attention with context parallelism.

    Posts one send of `send_tensor` to `send_dst` and one receive into `recv_tensor`
    from `recv_src` within `cp_group`, and returns the request handles to wait on.
    Even ranks post the send first and odd ranks the receive first (presumably to
    avoid deadlock in the ring exchange — TODO confirm). With `batch_p2p_comm` the
    two ops go through `batch_isend_irecv`; otherwise `isend`/`irecv` are called
    directly, in that rank-dependent order.
    """
    dist = torch.distributed
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        send_op = dist.P2POp(dist.isend, send_tensor, send_dst, cp_group)
        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_src, cp_group)
        ordered_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return dist.batch_isend_irecv(ordered_ops)

    # Unbatched isend/irecv start immediately, so the call order itself matters.
    if send_first:
        send_req = dist.isend(send_tensor, send_dst, cp_group)
        recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]
    recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
    send_req = dist.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]


517
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
    """
    Merge partial outputs of each step in Attention with context parallelism.

    Accumulates `out_per_step * exp(softmax_lse_per_step - softmax_lse)` into `out`
    in place. The scale's dim 2 is moved to `seq_dim` and a trailing singleton dim
    is added, so it broadcasts against `out_per_step`.
    """
    rescale = torch.exp(softmax_lse_per_step - softmax_lse)
    rescale = rescale.movedim(2, seq_dim).unsqueeze(-1)
    out.add_(out_per_step * rescale)


526
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """
    Merge softmax stats of each step in Attention with context parallelism.

    Updates `softmax_lse` in place to
    log(exp(softmax_lse) + exp(softmax_lse_per_step)).

    `torch.logaddexp` computes the same quantity as the previous
    `max + log(1 + exp(min - max))` form. It returns that infinity instead of
    NaN when both inputs are the same infinity (the old form computed inf - inf),
    and it avoids the precision loss of log(1 + x) for tiny x.
    """
    softmax_lse.copy_(torch.logaddexp(softmax_lse, softmax_lse_per_step))
533
534


535
class AttnFuncWithCP(torch.autograd.Function):
536
    """
537
538
    Attention implementation with context parallelism.
    Split attention compute into multiple steps, and overlap current-step
539
540
541
542
    compute with next-step communication.
    """

    @staticmethod
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
    def forward(
        ctx,
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        seq_offsets_q,
        seq_offsets_k,
        seq_offsets_v,
        seq_offsets_o,
        dropout_p,
        cp_group,
        cp_global_ranks,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    ):
569
570
571
572
573
574
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
575
        recv_src = cp_global_ranks[(rank - 1) % cp_size]
576
577
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

578
579
        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
580

581
582
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

583
        if causal:
584
585
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
586
                q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
587
588
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
589
                q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
590
        if attn_bias is not None:
591
            assert len(attn_bias.shape) == 4, (
592
593
594
595
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
596
597
598
599
600
601
            attn_bias_ = attn_bias.view(
                *attn_bias.shape[:-2],
                2,
                attn_bias.shape[-2] // 2,
                2 * cp_size,
                attn_bias.shape[-1] // (2 * cp_size),
602
603
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
604
605
            attn_bias = attn_bias.view(
                *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size)
606
            )
607
        assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8"
608
609
610
611
612
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
613

614
615
616
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
617
        attn_bias_inputs = [None, None]
618
619
620
621
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
622
        attn_biases = [None for _ in range(cp_size)]
623
624
625
626
627
628
629
630
631
632

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
        p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
        send_recv_reqs = [[], []]

633
        for i in range(cp_size + 1):
634
            if i < cp_size:
635
                with torch.cuda.stream(flash_attn_streams[i % 2]):
636
                    # wait until KV is received
637
                    for req in send_recv_reqs[(i + 1) % 2]:
638
639
                        req.wait()

640
641
642
643
644
645
646
647
648
649
650
651
652
                    if i < (cp_size - 1):
                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                            rank,
                            p2p_comm_buffers[i],
                            send_dst,
                            p2p_comm_buffers[i + 1],
                            recv_src,
                            cp_group,
                            batch_p2p_comm,
                        )

                    kv_inputs[i % 2] = p2p_comm_buffers[i]
653
654
                    if causal:
                        if i == 0:
655
                            if use_fused_attention:
656
657
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
658
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
659
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
660
661
662
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        2, k.shape[0], -1, *k.shape[-2:]
                                    )
663
664
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
665
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
666
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
667
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:])
668
                                elif qkv_format == "thd":
669
                                    q_inputs[i % 2] = q
670
671
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
672
673
674
675
676
677
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias[..., idx, :],
                                            attn_bias[..., (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
678
                                    ).contiguous()
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q,
                                        max_seqlen_k,
                                        cu_seqlens_q,
                                        cu_seqlens_k,
                                        q_inputs[i % 2],
                                        kv_inputs[i % 2][0],
                                        kv_inputs[i % 2][1],
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type=attn_mask_type,
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
                                        seq_offsets_q=seq_offsets_q,
                                        seq_offsets_k=seq_offsets_k,
                                        seq_offsets_v=seq_offsets_v,
                                        seq_offsets_o=seq_offsets_o,
                                    )
702
                                )
703
704
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
705
706
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
707
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
708
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
                                    cu_seqlens_q,
                                    cu_seqlens_k,
                                    max_seqlen_q,
                                    max_seqlen_k,
                                    dropout_p,
                                    softmax_scale,
                                    causal=True,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
732
                                )
733
                        elif i <= rank:
734
                            if use_fused_attention:
735
736
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
737
                                    q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
738
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
739
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
740
741
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
742
                                    q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
743
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
744
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous()
745
                                elif qkv_format == "thd":
746
                                    q_inputs[i % 2] = q
747
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
748
749
750
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                        kv_inputs[i % 2], cu_seqlens_k, 0
                                    )
751
752
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
                                    attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q,
                                        max_seqlen_k // 2,
                                        cu_seqlens_q,
                                        cu_seqlens_k // 2,
                                        q_inputs[i % 2],
                                        kv_inputs[i % 2][0],
                                        kv_inputs[i % 2][1],
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type="padding" if padding else "no_mask",
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
                                        seq_offsets_q=seq_offsets_q,
                                        seq_offsets_k=(
                                            None if seq_offsets_k is None else seq_offsets_k // 2
                                        ),
                                        seq_offsets_v=(
                                            None if seq_offsets_v is None else seq_offsets_v // 2
                                        ),
                                        seq_offsets_o=seq_offsets_o,
                                    )
781
                                )
782
783
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
784
785
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
786
                                q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
787
788
                                if qkv_format == "thd":
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
789
790
791
                                    kv_inputs[i % 2] = tex.thd_read_half_tensor(
                                        kv_inputs[i % 2], cu_seqlens_k, 0
                                    )
792
793
                                else:
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
794
                                    kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous()
795
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
796
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
797
798
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
                                    cu_seqlens_q,
                                    cu_seqlens_k // 2,
                                    max_seqlen_q,
                                    max_seqlen_k // 2,
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
821
822
823
                                )
                        else:
                            if use_fused_attention:
824
825
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
826
                                    q_inputs[i % 2] = q[:, 1, ...].contiguous()
827
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
828
829
830
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(
                                        2, k.shape[0], -1, *k.shape[-2:]
                                    )
831
832
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
833
                                    q_inputs[i % 2] = q[1].contiguous()
834
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
835
                                    kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-3:])
836
837
                                elif qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
838
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
839
840
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
841
842
843
844
845
846
                                    attn_bias_inputs[i % 2] = torch.cat(
                                        (
                                            attn_bias_[..., 1, :, idx, :],
                                            attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :],
                                        ),
                                        dim=-1,
847
                                    ).contiguous()
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                    fused_attn_fwd(
                                        is_training,
                                        max_seqlen_q // 2,
                                        max_seqlen_k,
                                        cu_seqlens_q // 2,
                                        cu_seqlens_k,
                                        q_inputs[i % 2],
                                        kv_inputs[i % 2][0],
                                        kv_inputs[i % 2][1],
                                        TE_DType[q.dtype],
                                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                        attn_scale=softmax_scale,
                                        dropout=dropout_p,
                                        qkv_layout=qkv_layout,
                                        attn_mask_type="padding" if padding else "no_mask",
                                        attn_bias_type=attn_bias_type,
                                        attn_bias=attn_bias_inputs[i % 2],
                                        seq_offsets_q=(
                                            None if seq_offsets_q is None else seq_offsets_q // 2
                                        ),
                                        seq_offsets_k=seq_offsets_k,
                                        seq_offsets_v=seq_offsets_v,
                                        seq_offsets_o=(
                                            None if seq_offsets_o is None else seq_offsets_o // 2
                                        ),
                                    )
875
                                )
876
877
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
878
                            else:
879
880
                                if qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
881
                                    q_inputs[i % 2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
882
883
                                else:
                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
884
                                    q_inputs[i % 2] = (
885
                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
886
                                    )
887
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
888
                                kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
889
890
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
                                (
                                    _,
                                    _,
                                    _,
                                    _,
                                    out_per_step[i],
                                    softmax_lse_per_step[i],
                                    _,
                                    rng_states[i],
                                ) = _flash_attn_forward(
                                    q_inputs[i % 2],
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
                                    cu_seqlens_q // 2,
                                    cu_seqlens_k,
                                    max_seqlen_q // 2,
                                    max_seqlen_k,
                                    dropout_p,
                                    softmax_scale,
                                    causal=False,
                                    return_softmax=False,
                                    **fa_optional_forward_kwargs,
913
914
915
                                )
                    else:
                        if use_fused_attention:
916
917
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
918
919
920
921
922
923
                                attn_bias_inputs[i % 2] = torch.cat(
                                    (
                                        attn_bias[..., idx, :],
                                        attn_bias[..., (2 * cp_size - idx - 1), :],
                                    ),
                                    dim=-1,
924
                                ).contiguous()
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
                            out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
                                fused_attn_fwd(
                                    is_training,
                                    max_seqlen_q,
                                    max_seqlen_k,
                                    cu_seqlens_q,
                                    cu_seqlens_k,
                                    q,
                                    kv_inputs[i % 2][0],
                                    kv_inputs[i % 2][1],
                                    TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type=attn_mask_type,
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i % 2],
                                    seq_offsets_q=seq_offsets_q,
                                    seq_offsets_k=seq_offsets_k,
                                    seq_offsets_v=seq_offsets_v,
                                    seq_offsets_o=seq_offsets_o,
                                )
948
                            )
949
950
                            if len(rest) > 0:
                                attn_biases[i] = rest[0]
951
                        else:
952
                            # [b, sq, np, hn] -> [b*sq, np, hn]
953
                            q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
954
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
                            kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:])
                            (
                                _,
                                _,
                                _,
                                _,
                                out_per_step[i],
                                softmax_lse_per_step[i],
                                _,
                                rng_states[i],
                            ) = _flash_attn_forward(
                                q_inputs[i % 2],
                                kv_inputs[i % 2][0],
                                kv_inputs[i % 2][1],
                                cu_seqlens_q,
                                cu_seqlens_k,
                                max_seqlen_q,
                                max_seqlen_k,
                                dropout_p,
                                softmax_scale,
                                causal=False,
                                return_softmax=False,
                                **fa_optional_forward_kwargs,
978
                            )
979
980
981
982

            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
983
                    flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done)
984

985
986
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
987
                    softmax_lse_per_step[i - 1].squeeze_(-1)
988

989
                with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
990
991
992
                    if i == 1:
                        out = torch.empty_like(q).zero_()
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
993
                        if causal and qkv_format != "thd":
994
995
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
996
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
997
                            )
998
999
1000
1001
                    elif (i - 1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(
                            softmax_lse, softmax_lse_per_step[i - 1]
                        )
1002
                    else:
1003
                        if qkv_format == "thd":
1004
1005
1006
                            tex.thd_second_half_lse_correction(
                                softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q, q.size(0)
                            )
1007
                        else:
1008
1009
1010
                            flash_attn_fwd_softmax_lse_correction(
                                softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1]
                            )
1011
1012

                if i < cp_size:
1013
                    flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
1014
1015
1016
1017

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
1018
1019
        if qkv_format in ["bshd", "sbhd"]:
            seq_dim = qkv_format.index("s")
1020
        for i in range(cp_size):
1021
1022
1023
1024
1025
1026
            if qkv_format == "bshd":
                out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
                out_ = out[:, 1, ...]
            elif qkv_format == "sbhd":
                out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
                out_ = out[1]
1027

1028
            if i <= rank or not causal:
1029
                if qkv_format in ["bshd", "sbhd"]:
1030
1031
1032
1033
1034
1035
1036
                    flash_attn_fwd_out_correction(
                        out.view(*out_per_step[i].shape),
                        out_per_step[i],
                        seq_dim,
                        softmax_lse,
                        softmax_lse_per_step[i],
                    )
1037
                elif qkv_format == "thd":
1038
1039
1040
1041
1042
1043
1044
1045
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
                        cu_seqlens_q,
                        False,
                    )
1046
1047
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
1048
            else:
1049
                if qkv_format in ["bshd", "sbhd"]:
1050
1051
1052
1053
1054
1055
1056
                    flash_attn_fwd_out_correction(
                        out_,
                        out_per_step[i],
                        seq_dim,
                        softmax_lse_[..., 1, :],
                        softmax_lse_per_step[i],
                    )
1057
                elif qkv_format == "thd":
1058
1059
1060
1061
1062
1063
1064
1065
                    tex.thd_out_correction(
                        out,
                        out_per_step[i],
                        softmax_lse,
                        softmax_lse_per_step[i],
                        cu_seqlens_q,
                        True,
                    )
1066
1067
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
1068
1069

        kv = p2p_comm_buffers[-1]
1070
        if use_fused_attention:
1071
1072
1073
1074
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
1075
1076
        else:
            out = out.view(-1, *out.shape[-2:])
1077

1078
        ctx.save_for_backward(
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
            q,
            kv,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seq_offsets_q,
            seq_offsets_k,
            seq_offsets_v,
            seq_offsets_o,
            *rng_states,
            *attn_biases,
1091
        )
1092
1093
1094
1095
1096
1097
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
1098
        ctx.qkv_format = qkv_format
1099
        ctx.attn_mask_type = attn_mask_type
1100
1101
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
1102
        ctx.deterministic = deterministic
1103
        ctx.use_fused_attention = use_fused_attention
1104
1105
1106
1107
        return out

    @staticmethod
    def backward(ctx, dout):
1108
        (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
1109
        (seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o) = ctx.saved_tensors[6:10]
1110
        cp_size = get_distributed_world_size(ctx.cp_group)
1111
1112
        rng_states = ctx.saved_tensors[10 : 10 + cp_size]
        attn_biases = ctx.saved_tensors[10 + cp_size : 10 + cp_size * 2]
1113

1114
        rank = get_distributed_rank(ctx.cp_group)
1115
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
1116
1117
1118
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

1119
1120
        causal = "causal" in ctx.attn_mask_type
        padding = "padding" in ctx.attn_mask_type
1121
1122
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

1123
        if attn_biases[0] is not None:
1124
1125
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
1126
                *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device
1127
1128
1129
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
1130
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:]
1131
1132
1133
1134
            )
        else:
            attn_dbias = None

1135
        if causal:
1136
1137
1138
1139
            if ctx.qkv_format == "thd":
                softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
1140
1141
1142
                softmax_lse_ = softmax_lse.view(
                    *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
                )
1143
1144
1145
1146
1147
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
                if ctx.use_fused_attention:
                    # [b, np, sq//2] -> [b, np, sq//2, 1]
                    softmax_lse_.unsqueeze_(-1)

1148
1149
1150
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse.unsqueeze_(-1)
1151
1152
1153
1154
1155
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        # Flash Attn outputs
        dq = torch.empty_like(q)

1156
1157
1158
1159
        p2p_comm_buffers = [
            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
        ]
1160
1161
1162
        p2p_comm_buffers[0][0].copy_(kv)
        send_recv_reqs = []

1163
1164
1165
1166
1167
1168
        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

1169
1170
1171
1172
1173
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

1174
1175
            send_tensor = p2p_comm_buffers[i % 2]
            recv_tensor = p2p_comm_buffers[(i + 1) % 2]
1176
1177
1178
            if i == 0:
                send_tensor = send_tensor[0]
                recv_tensor = recv_tensor[0]
1179
            if i == (cp_size - 1):
1180
1181
1182
                send_tensor = send_tensor[1]
                recv_tensor = recv_tensor[1]

1183
1184
1185
            send_recv_reqs = flash_attn_p2p_communicate(
                rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
            )
1186

1187
            kv = p2p_comm_buffers[i % 2][0]
1188
            # In reversed order of fwd
1189
            if causal:
1190
                if i == (cp_size - 1):
1191
                    if ctx.use_fused_attention:
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                            kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
                            # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
1208
1209
                        elif ctx.qkv_format == "thd":
                            q_, kv_, out_, dout_ = q, kv, out, dout
1210
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1211
                        if attn_dbias is not None:
1212
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1213
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_k,
                            cu_seqlens_q,
                            cu_seqlens_k,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
1226
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1227
1228
1229
1230
                            seq_offsets_q,
                            seq_offsets_k,
                            seq_offsets_v,
                            seq_offsets_o,
1231
1232
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1233
                            qkv_layout=qkv_layout,
1234
                            attn_mask_type=ctx.attn_mask_type,
1235
                            attn_bias_type=ctx.attn_bias_type,
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, 0]
                        _flash_attn_backward(
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
                            cu_seqlens_q,
                            cu_seqlens_k,
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_k,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            True,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
1268
                        )
1269
                elif i >= (cp_size - rank - 1):
1270
                    if ctx.use_fused_attention:
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous()
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
                            # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
                            kv_ = kv[:, 0, ...].contiguous()
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
1287
1288
1289
1290
                        elif ctx.qkv_format == "thd":
                            q_, out_, dout_ = q, out, dout
                            # [2, t, np, hn] -> [2, t/2, np, hn]
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
1291
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1292
                        if attn_dbias is not None:
1293
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1294
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_k // 2,
                            cu_seqlens_q,
                            cu_seqlens_k // 2,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
1307
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1308
1309
1310
1311
                            seq_offsets_q,
                            None if seq_offsets_k is None else seq_offsets_k // 2,
                            None if seq_offsets_v is None else seq_offsets_v // 2,
                            seq_offsets_o,
1312
1313
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1314
                            qkv_layout=qkv_layout,
1315
                            attn_mask_type="padding" if padding else "no_mask",
1316
                            attn_bias_type=ctx.attn_bias_type,
1317
1318
1319
1320
1321
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
1322
1323
1324
1325
1326
1327
                        if ctx.qkv_format == "thd":
                            # [2, t, np, hn] -> [2, t/2, np, hn]
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
                        else:
                            # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
1328
1329
1330
1331
1332
1333
1334
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse,
                            dq_,
                            dkv_[0],
                            dkv_[1],
                            cu_seqlens_q,
                            cu_seqlens_k // 2,
                            ctx.max_seqlen_q,
                            ctx.max_seqlen_k // 2,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
1353
1354
1355
                        )
                else:
                    if ctx.use_fused_attention:
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous()
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                            kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous()
                            dout_ = dout[:, 1, ...].contiguous()
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            q_ = q[1].contiguous()
                            # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
1372
1373
1374
1375
1376
1377
                        elif ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
                            kv_ = kv
1378
                        aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
1379
                        if attn_dbias is not None:
1380
                            aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1381
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
                            ctx.max_seqlen_q // 2,
                            ctx.max_seqlen_k,
                            cu_seqlens_q // 2,
                            cu_seqlens_k,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            dout_,
                            TE_DType[q.dtype],
                            TE_DType[kv.dtype],
                            aux_ctx_tensors,
1394
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1395
1396
1397
1398
                            None if seq_offsets_q is None else seq_offsets_q // 2,
                            seq_offsets_k,
                            seq_offsets_v,
                            None if seq_offsets_o is None else seq_offsets_o // 2,
1399
1400
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1401
                            qkv_layout=qkv_layout,
1402
                            attn_mask_type="padding" if padding else "no_mask",
1403
                            attn_bias_type=ctx.attn_bias_type,
1404
1405
                        )
                    else:
1406
1407
1408
1409
1410
1411
                        if ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
1412
1413
1414
1415
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
1416
1417
1418
1419
1420
1421
1422
                        if ctx.qkv_format == "thd":
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                            dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
1423
1424
1425
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
                            dout_,
                            q_,
                            kv_[0],
                            kv_[1],
                            out_,
                            softmax_lse_,
                            dq_,
                            dkv_[0],
                            dkv_[1],
                            cu_seqlens_q // 2,
                            cu_seqlens_k,
                            ctx.max_seqlen_q // 2,
                            ctx.max_seqlen_k,
                            ctx.dropout_p,
                            ctx.softmax_scale,
                            False,
                            rng_state=rng_states[cp_size - i - 1],
                            **fa_optional_backward_kwargs,
1444
1445
1446
                        )
            else:
                if ctx.use_fused_attention:
1447
                    aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
1448
                    if attn_dbias is not None:
1449
                        aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
1450
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_k,
                        cu_seqlens_q,
                        cu_seqlens_k,
                        q,
                        kv[0],
                        kv[1],
                        out,
                        dout,
                        TE_DType[q.dtype],
                        TE_DType[kv.dtype],
                        aux_ctx_tensors,
1463
                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1464
1465
1466
1467
                        seq_offsets_q,
                        seq_offsets_k,
                        seq_offsets_v,
                        seq_offsets_o,
1468
1469
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
1470
                        qkv_layout=qkv_layout,
1471
                        attn_mask_type=ctx.attn_mask_type,
1472
                        attn_bias_type=ctx.attn_bias_type,
1473
1474
1475
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
1476
1477
                    q_ = q.view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
1478
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
1479
1480
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
1481
                    # [b, sq, np, hn] -> [b*sq, np, hn]
1482
1483
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
1484
1485
                    if _flash_attn_2_3_plus:
                        fa_optional_backward_kwargs["window_size"] = [-1, -1]
1486
                    _flash_attn_backward(
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
                        dout_,
                        q_,
                        kv_[0],
                        kv_[1],
                        out_,
                        softmax_lse,
                        dq_,
                        dkv_[0],
                        dkv_[1],
                        cu_seqlens_q,
                        cu_seqlens_k,
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_k,
                        ctx.dropout_p,
                        ctx.softmax_scale,
                        False,
                        **fa_optional_backward_kwargs,
1504
1505
                    )

1506
            if i >= (cp_size - rank - 1) or not causal:
1507
1508
1509
1510
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
1511
1512
1513
1514
1515
1516
                if ctx.qkv_format == "bshd":
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
                elif ctx.qkv_format == "sbhd":
                    # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
                    dq_ = dq_.view(-1, *dq.shape[-3:])
1517

1518
            if causal:
1519
                if i > (cp_size - rank - 1):
1520
                    dq.add_(dq_)
1521
1522
                elif i == (cp_size - rank - 1):
                    if rank == (cp_size - 1):
1523
1524
                        dq.copy_(dq_)
                    else:
1525
1526
1527
1528
1529
1530
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
1531
1532
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add")
1533
                elif i > 0:
1534
1535
1536
1537
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
1538
1539
                    elif ctx.qkv_format == "thd":
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add")
1540
                else:
1541
1542
1543
1544
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
1545
1546
                    elif ctx.qkv_format == "thd":
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy")
1547
1548
1549
1550
1551
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
1552

1553
            if attn_dbias is not None:
1554
                idx = (rank + i + 1) % cp_size
1555
                if i == (cp_size - 1) or not causal:
1556
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
1557
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
1558
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
1559
1560
                    attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size - rank - 1):
1561
1562
1563
1564
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
1565
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2)
1566
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
1567
                    attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :])
1568

1569
1570
1571
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
1572

1573
            dkv = p2p_comm_buffers[(i + 1) % 2][1]
1574
1575
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
1576
            if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
1577
1578
1579
1580
1581
1582
                if ctx.qkv_format == "bshd":
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                elif ctx.qkv_format == "sbhd":
                    # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
1583
1584
1585
1586
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)
1587

1588
            if causal:
1589
                if i == (cp_size - 1):
1590
                    if rank == 0:
1591
1592
1593
1594
1595
1596
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
1597
1598
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy")
1599
1600
                    else:
                        dkv.add_(dkv_)
1601
1602
                elif i >= (cp_size - rank - 1):
                    if i == 0 and rank == (cp_size - 1):
1603
1604
1605
1606
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
1607
1608
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none")
1609
                    else:
1610
1611
1612
1613
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
1614
1615
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none")
1616
1617
1618
1619
1620
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
1621
1622
1623
1624
1625
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

1626
        if causal:
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                dq = dq.view(q.shape[0], -1, *q.shape[-2:])
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                dq = dq.view(-1, *q.shape[-3:])
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:])

        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
        return (
            None,
            dq,
            dkv[0],
            dkv[1],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            attn_dbias,
            None,
            None,
        )
1667
1668
1669


def attn_forward_func_with_cp(
    is_training,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seq_offsets_q,
    seq_offsets_k,
    seq_offsets_v,
    seq_offsets_o,
    dropout_p,
    cp_group,
    cp_global_ranks,
    cp_stream,
    softmax_scale=None,
    qkv_format="bshd",
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attn_bias=None,
    deterministic=False,
    use_fused_attention=False,
) -> torch.Tensor:
    """Attention implementation with context parallelism.

    Validates the requested configuration, then forwards every argument,
    unchanged and in order, to ``AttnFuncWithCP.apply``.

    Constraints checked here (violations raise ``AssertionError``):
      * ``qkv_format`` must be one of ``"bshd"``, ``"sbhd"``, ``"thd"``;
      * ``"sbhd"`` is only accepted when ``use_fused_attention`` is True;
      * ``"thd"`` with ``use_fused_attention`` requires a ``"padding"`` or
        ``"padding_causal"`` mask type;
      * ``attn_bias`` may only be given with ``use_fused_attention`` and a mask
        type whose name does not contain ``"padding"``.

    Returns
    -------
    The output of ``AttnFuncWithCP.apply``.
    """
    backend_name = "FusedAttention" if use_fused_attention else "FlashAttention"

    assert qkv_format in (
        "bshd",
        "sbhd",
        "thd",
    ), f"QKV format of {qkv_format} is not supported with context parallelism!"

    assert use_fused_attention or qkv_format != "sbhd", "FlashAttention does not support sbhd format!"

    thd_with_fused = qkv_format == "thd" and use_fused_attention
    assert not thd_with_fused or attn_mask_type in ("padding", "padding_causal"), (
        f"Context parallelism is not supported for {attn_mask_type} mask type and "
        f"{qkv_format} format with {backend_name}!"
    )

    bias_allowed = use_fused_attention and "padding" not in attn_mask_type
    assert attn_bias is None or bias_allowed, (
        'Attention bias is only supported with FusedAttention and "causal" '
        'or "no_mask" mask types!'
    )

    return AttnFuncWithCP.apply(
        is_training,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        seq_offsets_q,
        seq_offsets_k,
        seq_offsets_v,
        seq_offsets_o,
        dropout_p,
        cp_group,
        cp_global_ranks,
        cp_stream,
        softmax_scale,
        qkv_format,
        attn_mask_type,
        attn_bias_type,
        attn_bias,
        deterministic,
        use_fused_attention,
    )


1743
1744
1745
1746
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[float] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
            When below 1.0, only `int(dim * rotary_percent)` channels get frequencies.
        seq_len_interpolation_factor: float, default = None
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            pre-trained max_position_embeddings before position interpolation
        rotary_base: float, default = 10000.0
            base of the geometric progression of inverse frequencies,
            `inv_freq[i] = 1 / rotary_base ** (2 * i / dim)`.
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        # One inverse frequency per pair of channels: 1 / base^(2i / dim).
        inv_freq = 1.0 / (
            rotary_base
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]`
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # Outer product: freqs[i, j] = seq[i] * inv_freq[j]
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # The frequency table is duplicated along the last dim so that each half
        # of the rotated channels sees the same angles (see _rotate_half usage).
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))

1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836

class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Apply RoPE to `t` with the fused `tex` kernels.

        Parameters
        ----------
        t: torch.Tensor
            input tensor laid out according to `tensor_format`.
        freqs: torch.Tensor
            RoPE frequencies; cast to float32 before being passed to the kernel.
        tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
            layout of `t`.
        cu_seqlens: torch.Tensor, default = None
            cumulative sequence lengths; only read for the 'thd' format.

        Raises
        ------
        ValueError
            if `tensor_format` is not one of the supported layouts.
        """
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            # The kernel is called on the (0, 1)-transposed view of `t` with the
            # boolean flag set, and the result is transposed back to `bshd`.
            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Propagate `grad_output` through the fused RoPE kernels.

        Returns exactly one gradient per `forward` input
        (`t`, `freqs`, `tensor_format`, `cu_seqlens`); only `t` receives one.
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # One entry per forward input: t, freqs, tensor_format, cu_seqlens.
        return grad_input, None, None, None


1869
1870
1871
1872
1873
1874
1875
1876
1877
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


1878
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Returns
    -------
    torch.Tensor
        `t` with its first `d2` channels rotated; the remaining channels pass through.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    is_bshd = tensor_format == "bshd"
    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1 if is_bshd else 0]

    # Only the first `cur_seq_len` positions of the table are used.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    angles = freqs[:cur_seq_len]
    if is_bshd:
        angles = angles.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]

    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(angles).to(t.dtype)
    sin_ = torch.sin(angles).to(t.dtype)

    # Channels beyond the table width are left untouched.
    rot_dim = angles.shape[-1]
    rotated, passthrough = t[..., :rot_dim], t[..., rot_dim:]

    # cosine term plus the sign-adjusted, half-swapped sine term
    rotated = (rotated * cos_) + (_rotate_half(rotated) * sin_)
    return torch.cat((rotated, passthrough), dim=-1)


cyanguwa's avatar
cyanguwa committed
1941
class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along one dimension (plain tensors and `Float8Tensor`).

    The backward pass concatenates the incoming gradients. When they are
    adjacent views of one buffer, it returns a single aliasing view instead of
    copying with `torch.cat`.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        """Split `mixed_x_layer` with `torch.split`, keeping FP8 metadata for `Float8Tensor`."""
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        if isinstance(mixed_x_layer, Float8Tensor):
            # Split the raw FP8 payload and rewrap each chunk like the input.
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)

    @staticmethod
    def _grads_are_adjacent_views(grads, data_of, split_sizes, split_dim):
        """Return True when `grads` tile one storage back-to-back along `split_dim`.

        `data_of(g)` returns the tensor whose storage and storage offset are
        inspected (the gradient itself, or its `_data` payload for FP8). Every
        gradient must have the strides of `grads[0]`, the shape of `grads[0]`
        with `split_dim` set to its split size, share `grads[0]`'s storage, and
        start exactly `strides[split_dim]` elements per preceding slice after
        `grads[0]`. Under those conditions a single strided view over that
        storage equals the concatenation.
        """
        strides = grads[0].stride()
        first_data = data_of(grads[0])
        data_ptr = first_data.untyped_storage().data_ptr()
        base_offset = first_data.storage_offset()
        base_shape = list(grads[0].shape)
        step = strides[split_dim]
        elapsed = 0
        for tensor, size in zip(grads, split_sizes):
            # Fresh copy per slice so the reference shape is never mutated.
            expected_shape = list(base_shape)
            expected_shape[split_dim] = size
            if (
                tensor.stride() != strides
                or list(tensor.shape) != expected_shape
                or data_of(tensor).untyped_storage().data_ptr() != data_ptr
                or data_of(tensor).storage_offset() != base_offset + elapsed * step
            ):
                return False
            elapsed += size
        return True

    @staticmethod
    def _alias_as_concat(grad, data, split_sizes, split_dim):
        """Build a tensor over `data`'s storage spanning all slices along `split_dim`.

        Shape and strides come from `grad`; storage, offset and dtype from `data`.
        """
        ret = torch.Tensor().to(device=grad.device, dtype=data.dtype)
        new_shape = list(grad.shape)
        new_shape[split_dim] = sum(split_sizes)
        ret.set_(
            data.untyped_storage(),
            data.storage_offset(),
            new_shape,
            grad.stride(),
        )
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Concatenate the split gradients, zero-copy when they already share one buffer."""
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            # An uneven final chunk will not match this size; it simply fails
            # the zero-copy check and is handled by torch.cat below.
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

        first = grad_outputs[0]
        if isinstance(first, Float8Tensor):
            # FP8 payloads live in `_data`; the view is built on that buffer.
            # NOTE(review): assumes the wrapper's strides match `_data`'s strides.
            if _SplitAlongDim._grads_are_adjacent_views(
                grad_outputs, lambda g: g._data, split_sizes, split_dim
            ):
                ret = _SplitAlongDim._alias_as_concat(first, first._data, split_sizes, split_dim)
                return Float8Tensor.make_like(first, data=ret), None, None

            grad_outputs_data = [x._data for x in grad_outputs]
            return (
                Float8Tensor.make_like(first, data=torch.cat(grad_outputs_data, dim=split_dim)),
                None,
                None,
            )

        if _SplitAlongDim._grads_are_adjacent_views(
            grad_outputs, lambda g: g, split_sizes, split_dim
        ):
            return _SplitAlongDim._alias_as_concat(first, first, split_sizes, split_dim), None, None

        return torch.cat(grad_outputs, dim=split_dim), None, None
2049
2050
2051
2052
2053
2054
2055
2056
2057


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        # Opt-in via NVTE_APPLY_QK_LAYER_SCALING=1; it needs a layer number to divide by.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
        )

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            4-D tensors in the `sbhd` or `bshd` format named by `qkv_layout`. `key_layer`
            and `value_layer` may have fewer heads than `query_layer`; they are then
            repeated (`repeat_interleave` on the head dim) to match.
        qkv_layout: str, default = `sbh3d`
            Memory layout; only its format part (`sbhd` or `bshd`) is used here.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            Unused by this backend.
        attn_mask_type: str, default = `causal`
            Passed to `FusedScaleMaskSoftmax` together with `attention_mask`.
        attention_mask: optional tensor or pair of tensors
            Passed unchanged to `FusedScaleMaskSoftmax`.
        core_attention_bias_type: str, default = `no_bias`
            `no_bias`, `pre_scale_bias` (bias added before scaling), `post_scale_bias`
            (bias added after scaling) or `alibi` (bias built by `get_alibi`).
        core_attention_bias: Optional[torch.Tensor]
            Bias for `pre_scale_bias`/`post_scale_bias`; broadcast against [b, np, sq, sk].
        alibi_slopes: Optional[torch.Tensor]
            Forwarded to `get_alibi` when `core_attention_bias_type == "alibi"`.

        Returns
        -------
        torch.Tensor
            `[sq, b, np * hn]` for `sbhd` input, `[b, sq, np * hn]` for `bshd` input.
        """

        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
        # Everything below indexes 4-D tensors; a 3-D `thd` input would otherwise fail
        # with an opaque indexing error further down.
        assert qkv_format in (
            "sbhd",
            "bshd",
        ), f"UnfusedDotProductAttention does not support qkv_format = {qkv_format}!"
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]

        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        # QK layer scaling only applies to FP16 inputs.
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            # Fewer KV heads than Q heads: repeat each KV head to match the Q head count.
            assert (
                query_layer.shape[2] % key_layer.shape[2] == 0
            ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                query_layer.shape[2] // key_layer.shape[2], dim=2
            )
            value_layer = value_layer.repeat_interleave(
                query_layer.shape[2] // value_layer.shape[2], dim=2
            )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            # Allocate next to the inputs (not on the current CUDA device) so inputs that
            # live on another device do not cause a device mismatch in baddbmm.
            device=query_layer.device,
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            matmul_result = (
                matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias
            ).view(-1, output_size[2], output_size[3])
            matmul_result *= scale
            # Same as the post-scale path: cast back so a bias in a wider dtype does not
            # leave the scores (and hence the BMM2 operand) in a different dtype than V.
            matmul_result = matmul_result.to(dtype=query_layer.dtype)

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes
                )
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            matmul_result = (
                (
                    matmul_result.view(
                        output_size[0], output_size[1], output_size[2], output_size[3]
                    )
                    + core_attention_bias
                )
                .view(-1, output_size[2], output_size[3])
                .to(dtype=query_layer.dtype)
            )

        else:
            # Without this, the uninitialised `torch.empty` buffer would be used as scores.
            raise ValueError(
                "UnfusedDotProductAttention does not support "
                f"core_attention_bias_type = {core_attention_bias_type}!"
            )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        # With QK layer scaling, the softmax receives the layer number as its extra scale.
        qk_layer_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask, attn_mask_type, qk_layer_scale
        )

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)
        else:
            # bshd: [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Turn q, k, v views of one interleaved QKV buffer in (s, b, ...) layout into
    three separate, contiguous tensors in (b, s, ...) layout."""

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # q, k and v arrive as non-contiguous views of a single QKV buffer; only
        # `query_layer` is handed to the kernel, which reaches the whole region through it.
        packed = tex.fa_prepare_fwd(query_layer)
        q_chunk, k_chunk, v_chunk = split_tensor_along_dim(packed, 0, 3)
        # Drop the leading size-1 dim that each chunk keeps after the split.
        return tuple(torch.squeeze(chunk, 0) for chunk in (q_chunk, k_chunk, v_chunk))

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # `fa_prepare_bwd` yields one combined gradient tensor; slice it along the last
        # dimension into one gradient per forward input.
        grad_q, grad_k, grad_v = split_tensor_along_dim(tex.fa_prepare_bwd(dq, dk, dv), -1, 3)
        return grad_q, grad_k, grad_v

2283

2284
def _get_qkv_layout(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qkv_format: str = "sbhd",
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of sequences in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
        The input `q`, or `q.contiguous()` if the original layout was not recognised.
    k: torch.Tensor
        The input `k`, or `k.contiguous()` if the original layout was not recognised.
    v: torch.Tensor
        The input `v`, or `v.contiguous()` if the original layout was not recognised.

    Raises
    ----------
    AssertionError
        If any of `q`, `k`, `v` does not have stride 1 in its last dimension.
    Exception
        If no supported layout is found even after making `q`, `k`, `v` contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        # Do the tensors share one storage allocation?
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        stride = k.stride()
        check_strides_kv = all(stride == x.stride() for x in [k, v])

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = all(shape == x.shape for x in [k, v])

        # Offsets i * d: tensors are interleaved in the second-to-last dim (h3d / h2d).
        last_dim_size = q.shape[-1]
        check_last_dim_offsets_qkv = all(
            i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v])
        )
        last_dim_size = k.shape[-1]
        check_last_dim_offsets_kv = all(
            i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v])
        )

        # Offsets i * h * d: tensors are interleaved before the head dim (3hd / 2hd).
        last_two_dims_size = q.shape[-1] * q.shape[-2]
        check_last_two_dims_offsets_qkv = all(
            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v])
        )
        last_two_dims_size = k.shape[-1] * k.shape[-2]
        check_last_two_dims_offsets_kv = all(
            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v])
        )

        # When h == 1 both offset patterns hold; the `not ...` guards make h3d/h2d win then.
        if (
            check_ptrs_qkv
            and check_strides_qkv
            and check_shapes_qkv
            and check_last_two_dims_offsets_qkv
            and not check_last_dim_offsets_qkv
        ):
            # sb3hd, bs3hd, t3hd
            qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
        elif (
            check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv
        ):
            # sbh3d, bsh3d, th3d
            qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
        elif (
            check_ptrs_kv
            and check_strides_kv
            and check_shapes_kv
            and check_last_two_dims_offsets_kv
            and not check_last_dim_offsets_kv
        ):
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv:
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
        elif check_strides_kv and check_shapes_kv:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            qkv_layout = "_".join([qkv_format] * 3)
        else:
            qkv_layout = "not_supported"

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == "not_supported":
        raise Exception("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v
2401

2402

2403
def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Check if sliding window size is compliant with mask type and if not,
    assert or set it to the appropriate size.

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type. Any type whose name contains `causal` (e.g. `causal`,
        `padding_causal`) is treated as causal.
    window_size: Optional[Tuple[int, int]], default = `None`
        `(left, right)` sliding window size. `None` selects the default for the mask
        type: `(-1, 0)` for causal masks, `(-1, -1)` otherwise.

    Returns
    -------
    window_size: Tuple[int, int]
        The default for the mask type when `window_size` is `None`, otherwise the
        given `window_size` unchanged.

    Raises
    ------
    AssertionError
        If the mask type is causal and `window_size[1]` is not 0.
    """
    if "causal" in attn_mask_type:
        if window_size is None:
            window_size = (-1, 0)
        else:
            # A causal mask never looks ahead, so the right window must be empty.
            assert (
                window_size[1] == 0
            ), "window_size[1] should be 0 when self_attn_mask_type includes 'causal'!"
    elif window_size is None:
        window_size = (-1, -1)
    return window_size
2421

2422

2423
class FlashAttention(torch.nn.Module):
    """Dot product attention backed by the flash-attn package:
    https://github.com/Dao-AILab/flash-attention
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        # Refuse to build against a flash-attn release outside the supported range.
        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            _flash_attn_version <= _flash_attn_max_version
        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = layer_number if layer_number is not None else 1
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
    ) -> torch.Tensor:
        """flash-attn fprop

        Inputs must be FP16/BF16 CUDA tensors in a layout from `QKVLayouts`. `sbhd` and
        `bshd` inputs are brought to a contiguous `[b, s, h, d]` layout (and, without
        context parallelism, folded to `[b * s, h, d]`); padded batches are packed with
        `PackTensors` before the kernel and unpacked with `UnpackTensor` afterwards.
        The result has heads and head size merged: `[s, b, h*d]` for `sbhd`,
        `[b, s, h*d]` for `bshd` and `[t, h*d]` for `thd`.
        """
        window_size = check_set_window_size(attn_mask_type, window_size)

        inputs = (query_layer, key_layer, value_layer)
        assert all(
            t.dtype in (torch.float16, torch.bfloat16) for t in inputs
        ), "FlashAttention currently only supports FP16 and BF16."
        assert all(
            t.is_cuda for t in inputs
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = cp_group is not None and get_distributed_world_size(cp_group) != 1
        qkv_format = "".join(ch for ch in qkv_layout.split("_")[0] if ch.isalpha())

        # Bring every input into a contiguous layout with batch first (or token first).
        if qkv_format == "sbhd":
            # The fused prepare kernel is only used for head size 128 for now.
            use_prepare_kernel = (
                qkv_layout == "sbh3d"
                and query_layer.shape[-1] == 128
                and query_layer.shape[0] * query_layer.shape[1] >= 512
            )
            if use_prepare_kernel:
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                    query_layer, key_layer, value_layer
                )
            else:
                query_layer, key_layer, value_layer = (
                    t.transpose(0, 1).contiguous() for t in (query_layer, key_layer, value_layer)
                )
        elif qkv_format in ("bshd", "thd"):
            query_layer, key_layer, value_layer = (
                t.contiguous() for t in (query_layer, key_layer, value_layer)
            )

        batch_size = query_layer.shape[0]

        if qkv_format in ("sbhd", "bshd"):
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]

            if not context_parallel:
                # [b, s, h, d] -> [b * s, h, d]
                query_layer, key_layer, value_layer = (
                    t.view(t.shape[0] * t.shape[1], *t.shape[2:])
                    for t in (query_layer, key_layer, value_layer)
                )

            if "padding" not in attn_mask_type:
                # Unpadded data: cumulative lengths of full-length sequences.
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
            else:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is not None:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    else:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    # Self-attention: KV shares the query's sequence boundaries.
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is not None and cu_seqlens_kv is not None:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    else:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer)

        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            # Derive missing maxima from the per-sequence lengths.
            if max_seqlen_q is None:
                max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
            if max_seqlen_kv is None:
                max_seqlen_kv = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).max().item()

        if context_parallel:
            assert window_size in (
                (-1, -1),
                (-1, 0),
            ), "Sliding window attention is not supported with context parallelism."
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    None,
                    None,
                    None,
                    None,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    softmax_scale=self.softmax_scale,
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                )
        else:

            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                for t in (query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv):
                    if t is not None:
                        t.activation_offloading = True

            # Optional kwargs are gated on the `_flash_attn_*_plus` flags (presumably
            # installed-version checks), so only supported arguments are passed.
            fa_optional_forward_kwargs = {}
            if _flash_attn_2_3_plus:
                fa_optional_forward_kwargs["window_size"] = window_size
            if _flash_attn_2_4_plus:
                fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
            if _flash_attn_2_4_1_plus:
                fa_optional_forward_kwargs["deterministic"] = self.deterministic

            with self.attention_dropout_ctx():
                output = flash_attn_forward_func(
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=self.softmax_scale,
                    causal="causal" in attn_mask_type,
                    **fa_optional_forward_kwargs,
                )

        if qkv_format in ("sbhd", "bshd") and "padding" in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "thd":
            # thd -> t(hd)
            output = output.view(output.shape[0], -1).contiguous()
        elif qkv_format in ("sbhd", "bshd"):
            # (bs)hd -> bs(hd), then for sbhd also -> sb(hd)
            output = output.view(batch_size, max_seqlen_q, -1)
            if qkv_format == "sbhd":
                output = output.transpose(0, 1)
            output = output.contiguous()

        return output
2653

2654

2655
def _combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension.

    Returns a zero-copy view over the storage of ``tensors[0]`` with a new axis of
    size ``len(tensors)`` inserted at ``dim``. The new axis gets the stride of axis
    ``dim - 1`` divided by ``len(tensors)``. This therefore assumes the inputs are
    equally spaced, interleaved slices of one packed buffer; nothing here checks that.
    ``Float8Tensor`` inputs are handled by viewing their ``_data`` and re-wrapping
    the view with ``Float8Tensor.make_like``.

    NOTE(review): for ``dim == 0`` the stride is taken from the last axis
    (``dim - 1 == -1``), which looks unintended — confirm callers never pass 0.
    """
    head = tensors[0]
    count = len(tensors)

    shape = list(head.shape)
    shape.insert(dim, count)
    strides = list(head.stride())
    strides.insert(dim, int(strides[dim - 1] / count))

    is_fp8 = isinstance(head, Float8Tensor)
    # FP8 tensors expose their raw bytes through `_data`; view that storage directly.
    source = head._data if is_fp8 else head

    view = torch.Tensor().to(device=head.device, dtype=source.dtype)
    view.set_(source.untyped_storage(), source.storage_offset(), shape, strides)

    if is_fp8:
        return Float8Tensor.make_like(head, data=view)
    return view
2682

2683

2684
2685
2686
2687
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input.

    Wraps ``fused_attn_fwd_qkvpacked`` / ``fused_attn_bwd_qkvpacked`` (FP8 or
    high precision) and, optionally, a flash-attn backward (``use_FAv2_bwd``).
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen,
        cu_seqlens,
        seq_offsets_q,
        seq_offsets_k,
        seq_offsets_v,
        seq_offsets_o,
        qkv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
    ):
        """Run the fused attention forward pass on a packed QKV tensor.

        When ``fp8`` is set, the backend is forced to ``FusedAttnBackend["FP8"]``
        and ``qkv_layout`` must describe a single packed group (no ``"_"`` in it).
        With ``fp8_meta["recipe"].fp8_mha``, ``qkv`` must already be a
        ``Float8Tensor`` and the output is returned as a ``Float8Tensor``.
        Otherwise ``qkv`` is cast to FP8 here and the output is cast back to
        ``qkv_dtype``.

        Whether ``backward`` runs in FP8 is decided by ``fp8`` together with the
        ``NVTE_FP8_DPA_BWD`` environment variable (default ``"1"``). When it does
        not, the high-precision ``qkv`` and output are saved instead of the FP8
        copies.

        Args:
            cu_seqlens: cumulative sequence lengths passed to the kernels (also
                used for both q and kv by the flash-attn backward).
            seq_offsets_q/k/v/o: offsets forwarded unchanged to the fused kernels.
            qkv: packed QKV input; its last three dims are flattened for FP8
                casting (presumably a ``3hd``/``h3d`` style layout -- see the
                FP8 layout assert).
            rng_gen: generator forwarded to the kernel (presumably for dropout).
            fp8_meta: FP8 scaling state; ``"scaling_fwd"`` entries are read and,
                for FP8 MHA, ``scale_inv[META_QKV]`` is overwritten.

        Returns:
            The attention output (``Float8Tensor`` for FP8 MHA, else a tensor).
        """
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
                # Use the incoming tensor's own dequantization scale.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split("_"))
            assert (
                qkv_group == 1
            ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}."
            if fp8_meta["recipe"].fp8_mha:
                qkv_fp8 = qkv._data
            else:
                # Cast as a 2D tensor (last three dims flattened), then restore shape.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(
                    qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                seq_offsets_q,
                seq_offsets_k,
                seq_offsets_v,
                seq_offsets_o,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will run in high precision: keep dequantized copies of
                # qkv and the output for it.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv = cast_from_fp8(
                    qkv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV,
                    fp8_dtype_forward,
                    TE_DType[qkv.dtype],
                ).view(qkv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            # Snapshot scales now; fp8_meta may be updated before backward runs.
            fp8_tensors = (
                qkv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            logger.debug("Running forward in %s", qkv.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training,
                max_seqlen,
                cu_seqlens,
                qkv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                seq_offsets_q,
                seq_offsets_k,
                seq_offsets_v,
                seq_offsets_o,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                rng_gen,
            )
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        # Backward runs in FP8 only if forward did and NVTE_FP8_DPA_BWD allows it.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens,
            seq_offsets_q,
            seq_offsets_k,
            seq_offsets_v,
            seq_offsets_o,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        # A non-FP8 backward always uses the F16 arbitrary-seqlen backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for ``forward``.

        ``d_out`` must be a ``Float8Tensor`` when ``fp8_meta["recipe"].fp8_mha``
        is set. Runs one of three paths:

        * flash-attn's ``flash_attn_cuda_bwd`` when ``use_FAv2_bwd``;
        * the FP8 fused kernel when ``ctx.fp8``;
        * the high-precision fused kernel otherwise.

        Returns one entry per ``forward`` input:

        * the gradient for ``qkv``;
        * the bias gradient for ``attn_bias`` unless ``attn_bias_type`` is
          ``"no_bias"`` or ``"alibi"``;
        * ``None`` everywhere else.
        """
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            qkv,
            out,
            cu_seqlens,
            seq_offsets_q,
            seq_offsets_k,
            seq_offsets_v,
            seq_offsets_o,
            qkv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            if ctx.attn_bias_type not in ["no_bias", "alibi"]:
                # The flash-attn backward produces no bias gradient, so there would
                # be nothing to return for attn_bias; fail clearly instead of with
                # an undefined-name error further down.
                raise RuntimeError(
                    f"use_FAv2_bwd does not support attn_bias_type={ctx.attn_bias_type!r}."
                )
            softmax_lse, rng_state = aux_ctx_tensors
            dqkv = torch.empty_like(qkv)
            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
            # Index 1 selects q/k/v (assumes a t3hd-style packed layout -- TODO confirm).
            d_out, q, k, v, out = [
                maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)
            ]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dqkv[:, 0],
                dqkv[:, 1],
                dqkv[:, 2],
                cu_seqlens,
                cu_seqlens,
                ctx.max_seqlen,
                ctx.max_seqlen,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dqkv = dqkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q,
                        seq_offsets_k,
                        seq_offsets_v,
                        seq_offsets_o,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dqkv = Float8Tensor(
                            data=dqkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dqkv_c_fp8 = dqkv_fp8.view(
                            -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                        )
                        dqkv = cast_from_fp8(
                            dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dqkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", qkv.dtype)
                    # d_out is raw uint8 when it was unwrapped from a Float8Tensor
                    # above; dequantize it for the high-precision kernel.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen,
                        cu_seqlens,
                        qkv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q,
                        seq_offsets_k,
                        seq_offsets_v,
                        seq_offsets_o,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                    )

        # Autograd expects exactly one gradient slot per forward() input (21 here):
        # index 7 is qkv, index 9 is attn_bias.
        grads = [None] * 21
        grads[7] = dqkv
        # no_bias/alibi have no bias tensor to differentiate; otherwise the fused
        # kernel's first extra output is the bias gradient.
        if ctx.attn_bias_type not in ["no_bias", "alibi"]:
            grads[9] = rest[0]
        return tuple(grads)
3072

3073

3074
3075
3076
3077
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    Wraps ``fused_attn_fwd_kvpacked`` / ``fused_attn_bwd_kvpacked`` (FP8 or
    high precision) and, optionally, a flash-attn backward (``use_FAv2_bwd``).
    """

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        seq_offsets_q,
        seq_offsets_k,
        seq_offsets_v,
        seq_offsets_o,
        q,
        kv,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
    ):
        """Run the fused attention forward pass on a query tensor and packed KV.

        When ``fp8`` is set, the backend is forced to ``FusedAttnBackend["FP8"]``.
        With ``fp8_meta["recipe"].fp8_mha``, ``q`` and ``kv`` must already be
        ``Float8Tensor`` objects and the output is returned as a
        ``Float8Tensor``. Otherwise ``qkv_layout`` must describe two groups
        (exactly one ``"_"``), ``q``/``kv`` are cast to FP8 here, and the output
        is cast back to ``qkv_dtype``.

        Whether ``backward`` runs in FP8 is decided by ``fp8`` together with the
        ``NVTE_FP8_DPA_BWD`` environment variable (default ``"1"``). When it does
        not, the high-precision ``q``, ``kv`` and output are saved instead of the
        FP8 copies.

        Args:
            cu_seqlens_q/cu_seqlens_kv: cumulative sequence lengths for q and kv.
            seq_offsets_q/k/v/o: offsets forwarded unchanged to the fused kernels.
            kv: packed KV input; its last three dims are flattened for FP8
                casting (presumably a ``2hd``/``h2d`` style layout -- see the
                FP8 layout assert).
            rng_gen: generator forwarded to the kernel (presumably for dropout).
            fp8_meta: FP8 scaling state; ``"scaling_fwd"`` entries are read and,
                for FP8 MHA, ``scale_inv[META_QKV]`` is overwritten.

        Returns:
            The attention output (``Float8Tensor`` for FP8 MHA, else a tensor).
        """
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                assert isinstance(q, Float8Tensor) and isinstance(
                    kv, Float8Tensor
                ), "q/kv must be Float8Tensors for FP8 MHA."
                # Use the incoming tensor's own dequantization scale.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split("_"))
                assert qkv_group == 2, (
                    "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                    f"but found {qkv_layout}."
                )
                q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
                    q.shape
                )
                # Cast kv as a 2D tensor (last three dims flattened), then restore shape.
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(
                    kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(kv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                kv_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                seq_offsets_q,
                seq_offsets_k,
                seq_offsets_v,
                seq_offsets_o,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                rng_gen,
            )
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will run in high precision: keep dequantized copies of
                # q, kv and the output for it.
                q = cast_from_fp8(
                    q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype]
                ).view(q.shape)
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv = cast_from_fp8(
                    kv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV,
                    fp8_dtype_forward,
                    TE_DType[kv.dtype],
                ).view(kv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            # Snapshot scales now; fp8_meta may be updated before backward runs.
            fp8_tensors = (
                q_fp8,
                kv_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            logger.debug("Running forward in %s", q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                kv,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                seq_offsets_q,
                seq_offsets_k,
                seq_offsets_v,
                seq_offsets_o,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                rng_gen,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None)

        # Backward runs in FP8 only if forward did and NVTE_FP8_DPA_BWD allows it.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            seq_offsets_q,
            seq_offsets_k,
            seq_offsets_v,
            seq_offsets_o,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        # A non-FP8 backward always uses the F16 arbitrary-seqlen backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for ``forward``.

        ``d_out`` must be a ``Float8Tensor`` when ``fp8_meta["recipe"].fp8_mha``
        is set. Runs one of three paths:

        * flash-attn's ``flash_attn_cuda_bwd`` when ``use_FAv2_bwd``;
        * the FP8 fused kernel when ``ctx.fp8``;
        * the high-precision fused kernel otherwise.

        Returns one entry per ``forward`` input:

        * gradients for ``q`` and ``kv``;
        * the bias gradient for ``attn_bias`` unless ``attn_bias_type`` is
          ``"no_bias"`` or ``"alibi"``;
        * ``None`` everywhere else.
        """
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            kv,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            seq_offsets_q,
            seq_offsets_k,
            seq_offsets_v,
            seq_offsets_o,
            q_fp8,
            kv_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            if ctx.attn_bias_type not in ["no_bias", "alibi"]:
                # The flash-attn backward produces no bias gradient, so there would
                # be nothing to return for attn_bias; fail clearly instead of with
                # an undefined-name error further down.
                raise RuntimeError(
                    f"use_FAv2_bwd does not support attn_bias_type={ctx.attn_bias_type!r}."
                )
            softmax_lse, rng_state = aux_ctx_tensors
            # Allocate gradients before q is rebound by maybe_contiguous below.
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
            # Index 1 of kv selects k/v (assumes a t2hd-style packed layout -- TODO confirm).
            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dkv[:, 0],
                dkv[:, 1],
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            dq = dq[..., : d_out.shape[-1]]
            dkv = dkv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        kv_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q,
                        seq_offsets_k,
                        seq_offsets_v,
                        seq_offsets_o,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dq = Float8Tensor(
                            data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                        dkv = Float8Tensor(
                            data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                        )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(
                            -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                        )
                        dkv = cast_from_fp8(
                            dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"],
                            META_DQKV,
                            fp8_dtype_backward,
                            ctx.qkv_dtype,
                        ).view(dkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s", q.dtype)
                    # d_out is raw uint8 when it was unwrapped from a Float8Tensor
                    # above; dequantize it for the high-precision kernel.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        kv,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q,
                        seq_offsets_k,
                        seq_offsets_v,
                        seq_offsets_o,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                    )

        # Autograd expects exactly one gradient slot per forward() input (24 here):
        # index 9 is q, index 10 is kv, index 12 is attn_bias.
        grads = [None] * 24
        grads[9] = dq
        grads[10] = dkv
        # no_bias/alibi have no bias tensor to differentiate; otherwise the fused
        # kernel's first extra output is the bias gradient.
        if ctx.attn_bias_type not in ["no_bias", "alibi"]:
            grads[12] = rest[0]
        return tuple(grads)

3513

3514
3515
3516
3517
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
    def forward(
        ctx,
        is_training,
        max_seqlen_q,
        max_seqlen_kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        seq_offsets_q,
        seq_offsets_k,
        seq_offsets_v,
        seq_offsets_o,
        q,
        k,
        v,
        qkv_dtype,
        attn_bias,
        attn_scale,
        dropout_p,
        fast_zero_fill,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
        fp8,
        fp8_meta,
    ):
        """Run the fused-attention forward pass on separate q, k, v tensors.

        When ``fp8`` is set, the FP8 backend is used. If ``fp8_meta["recipe"].fp8_mha``
        is set, q/k/v must already be ``Float8Tensor``s; otherwise they are cast to FP8
        using the ``META_QKV`` forward scaling. ``qkv_layout`` selects how they are cast:
        one packed qkv tensor, q plus packed kv, or three separate tensors.
        In all other cases, ``fused_attn_fwd`` runs directly in ``qkv_dtype``.

        Returns:
            The attention output. It is a ``Float8Tensor`` when running FP8 with
            ``fp8_mha``. Otherwise it is a tensor in ``qkv_dtype``.
        """
        logger = logging.getLogger("FusedAttnFunc")
        # NVTE_FP8_DPA_BWD=0 keeps the FP8 forward but runs the backward in high
        # precision. It is parsed once, and only when fp8 is enabled (short-circuit).
        use_fp8_bwd = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

        if fp8:
            logger.debug("Running forward in FP8")
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split("_"))

            if fp8_meta["recipe"].fp8_mha:
                assert (
                    isinstance(q, Float8Tensor)
                    and isinstance(k, Float8Tensor)
                    and isinstance(v, Float8Tensor)
                ), "q/k/v must be Float8Tensors for FP8 MHA."
                # The inputs are already FP8. Adopt their scale_inv so the kernel
                # dequantizes them correctly.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            elif qkv_group == 1:
                # Packed qkv: cast the combined buffer in one go, then split it again.
                dim = qkv_layout.find("3")
                qkv = _combine_tensors([q, k, v], dim)
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(
                    qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(qkv.shape)
                q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1])
                q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
            elif qkv_group == 2:
                # Separate q plus packed kv.
                q_fp8 = cast_to_fp8(
                    q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(q.shape)
                dim = qkv_layout.split("_")[1].find("2")
                kv = _combine_tensors([k, v], dim)
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(
                    kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                ).view(kv.shape)
                k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1])
                k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
            elif qkv_group == 3:
                # Fully separate q, k, v: cast each tensor on its own.
                q_fp8, k_fp8, v_fp8 = [
                    cast_to_fp8(
                        x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
                    ).view(x.shape)
                    for x in (q, k, v)
                ]

            out_fp8, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q_fp8,
                k_fp8,
                v_fp8,
                fp8_dtype_forward,
                fused_attention_backend,
                attn_bias,
                seq_offsets_q,
                seq_offsets_k,
                seq_offsets_v,
                seq_offsets_o,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                rng_gen,
            )

            if fp8_meta["recipe"].fp8_mha:
                # FP8 MHA: hand the FP8 output downstream without dequantizing it.
                out_ret = Float8Tensor(
                    data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)
            out_save = out_ret

            if fp8_meta["recipe"].fp8_mha and not use_fp8_bwd:
                # The backward will run in high precision. Replace the FP8 inputs
                # and the output saved for backward with dequantized copies.
                if qkv_group == 1:
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_no_fp8 = cast_from_fp8(
                        qkv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[qkv.dtype],
                    ).view(qkv.shape)
                    q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1])
                    q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                elif qkv_group == 2:
                    q = cast_from_fp8(
                        q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[q.dtype],
                    ).view(q.shape)
                    dim = qkv_layout.split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_no_fp8 = cast_from_fp8(
                        kv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV,
                        fp8_dtype_forward,
                        TE_DType[kv.dtype],
                    ).view(kv.shape)
                    k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1])
                    k, v = [x.squeeze(dim) for x in [k, v]]
                elif qkv_group == 3:
                    q, k, v = [
                        cast_from_fp8(
                            x._data,
                            fp8_meta["scaling_fwd"],
                            META_QKV,
                            fp8_dtype_forward,
                            TE_DType[x.dtype],
                        ).view(x.shape)
                        for x in (q, k, v)
                    ]
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"],
                    META_O,
                    fp8_dtype_forward,
                    qkv_dtype,
                ).view(out_fp8.shape)

            # Snapshot the forward scales so backward sees the values used here.
            fp8_tensors = (
                q_fp8,
                k_fp8,
                v_fp8,
                out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone(),
            )
        else:
            logger.debug("Running forward in %s", q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
                cu_seqlens_q,
                cu_seqlens_kv,
                q,
                k,
                v,
                qkv_dtype,
                fused_attention_backend,
                attn_bias,
                seq_offsets_q,
                seq_offsets_k,
                seq_offsets_v,
                seq_offsets_o,
                None,
                None,
                None,
                None,
                None,
                None,
                attn_scale,
                dropout_p,
                fast_zero_fill,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                rng_gen,
            )
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None, None)

        from .cpu_offload import CPUOffloadEnabled

        if CPUOffloadEnabled:
            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
            # NOTE(review): the layout is forced to separate sbhd tensors when
            # offloading. The reason is not visible here; confirm against cpu_offload.
            qkv_layout = "sbhd_sbhd_sbhd"
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        ctx.fp8 = use_fp8_bwd
        # The FP8 backward only needs the FP8 copies in fp8_tensors. The
        # high-precision q/k/v/out are saved only when backward runs in F16/BF16.
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
        ctx.save_for_backward(
            *qkvo_tensors,
            cu_seqlens_q,
            cu_seqlens_kv,
            seq_offsets_q,
            seq_offsets_k,
            seq_offsets_v,
            seq_offsets_o,
            *fp8_tensors,
            *aux_ctx_tensors,
        )
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        # A high-precision backward always uses the arbitrary-seqlen F16 backend.
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for q, k, v and, when a bias is in use, for attn_bias.

        There are three paths:

        - ``ctx.use_FAv2_bwd``: the flash-attention v2 backward kernel.
        - ``ctx.fp8``: the FP8 ``fused_attn_bwd``.
        - Otherwise: ``fused_attn_bwd`` in ``ctx.qkv_dtype``.

        Returns:
            One entry per ``forward`` input, in ``forward``'s positional order.
            Only the q, k, v slots and the attn_bias slot can be non-``None``.
            The attn_bias slot is filled only when ``attn_bias_type`` is neither
            ``"no_bias"`` nor ``"alibi"``.
        """
        logger = logging.getLogger("FusedAttnFunc")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert isinstance(
                d_out, Float8Tensor
            ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            cu_seqlens_kv,
            seq_offsets_q,
            seq_offsets_k,
            seq_offsets_v,
            seq_offsets_o,
            q_fp8,
            k_fp8,
            v_fp8,
            out_fp8,
            fwd_scales,
            fwd_scale_invs,
            *aux_ctx_tensors,
        ) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()

        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)

            def maybe_contiguous(x):
                # The kernel needs unit stride in the last dimension.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q,
                cu_seqlens_kv,
                ctx.max_seqlen_q,
                ctx.max_seqlen_kv,
                ctx.dropout_p,
                ctx.attn_scale,
                False,
                "causal" in ctx.attn_mask_type,
                None,
                rng_state,
            )
            # Trim to d_out's last-dim size. Presumably this drops any head_dim
            # padding added for the kernel (TODO confirm).
            dq = dq[..., : d_out.shape[-1]]
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False
                    )
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"],
                            META_DO,
                            fp8_dtype_backward,
                        ).view(d_out.shape)
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_fp8,
                        d_out_fp8,
                        fp8_dtype_forward,
                        fp8_dtype_backward,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q,
                        seq_offsets_k,
                        seq_offsets_v,
                        seq_offsets_o,
                        fwd_scale_invs[META_QKV],  # d_scale_qkv,
                        fwd_scale_invs[META_S],  # d_scale_s,
                        fwd_scale_invs[META_O],  # d_scale_o,
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],  # d_scale_do
                        ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],  # d_scale_dp
                        fwd_scales[META_S],  # q_scale_s
                        ctx.fp8_meta["scaling_bwd"].scale[META_DP],  # q_scale_dp
                        ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],  # q_scale_dqkv
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],  # amax_dp
                        ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],  # amax_dqkv
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                    )

                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # FP8 MHA: return the gradients as Float8Tensors without dequantizing.
                        dq, dk, dv = [
                            Float8Tensor(
                                data=grad_fp8,
                                fp8_meta=ctx.fp8_meta,
                                fp8_meta_forward=False,
                                fp8_meta_index=META_DQKV,
                                fp8_dtype=fp8_dtype_backward,
                                dtype=d_out_f8tensor.dtype,
                            )
                            for grad_fp8 in (dq_fp8, dk_fp8, dv_fp8)
                        ]
                    else:
                        # Dequantize to qkv_dtype, using the same packing as forward's qkv_layout.
                        qkv_group = len(ctx.qkv_layout.split("_"))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find("3")
                            dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim)
                            dqkv_c_fp8 = dqkv_fp8.view(
                                -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]
                            )
                            dqkv = cast_from_fp8(
                                dqkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dqkv_fp8.shape)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        elif qkv_group == 2:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dq_fp8.shape)
                            dim = ctx.qkv_layout.split("_")[1].find("2")
                            dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(
                                -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
                            )
                            dkv = cast_from_fp8(
                                dkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"],
                                META_DQKV,
                                fp8_dtype_backward,
                                ctx.qkv_dtype,
                            ).view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        elif qkv_group == 3:
                            dq, dk, dv = [
                                cast_from_fp8(
                                    grad_fp8.view(-1, grad_fp8.shape[-2] * grad_fp8.shape[-1]),
                                    ctx.fp8_meta["scaling_bwd"],
                                    META_DQKV,
                                    fp8_dtype_backward,
                                    ctx.qkv_dtype,
                                ).view(grad_fp8.shape)
                                for grad_fp8 in (dq_fp8, dk_fp8, dv_fp8)
                            ]
                else:
                    logger.debug("Running backward in %s", q.dtype)
                    # FP8 MHA with a high-precision backward: d_out arrived as FP8
                    # data, so dequantize it first.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        q,
                        k,
                        v,
                        out,
                        d_out,
                        ctx.qkv_dtype,
                        ctx.qkv_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q,
                        seq_offsets_k,
                        seq_offsets_v,
                        seq_offsets_o,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        None,
                        ctx.attn_scale,
                        ctx.dropout_p,
                        ctx.fast_zero_fill,
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                    )

        # autograd expects exactly one gradient per forward() input (25 of them).
        # Positions follow forward's signature: q=9, k=10, v=11, attn_bias=13.
        grads = [None] * 25
        grads[9], grads[10], grads[11] = dq, dk, dv
        # no_bias/alibi have no learnable bias. Otherwise the bias gradient is the
        # first extra output of fused_attn_bwd.
        # NOTE(review): `rest` is produced only by the fused paths, so the FAv2
        # path combined with a learnable bias still fails here, as before.
        if ctx.attn_bias_type not in ["no_bias", "alibi"]:
            grads[13] = rest[0]
        return tuple(grads)
4093

4094

4095
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |

    In addition, ``forward`` accepts ``torch.uint8`` q/k/v when ``fp8=True``, in which
    case the ``NVTE_FP8`` cuDNN sub-backend and ``fp8_meta`` are required.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        softmax_scale : float
            Scale applied to the attention scores before softmax.
        attention_dropout : float, default = 0.0
            Dropout probability; only applied while ``self.training`` is True.
        attention_dropout_ctx : Callable, default = ``nullcontext``
            Zero-argument context-manager factory entered around every kernel call
            (e.g. an RNG-tracker ``fork``).
        attention_type : str, default = "self"
            "self" or "cross"; selects how ``attention_mask`` is turned into cu_seqlens.
        layer_number : Optional[int], default = None
            Stored as ``self.layer_number`` (1 when None).
        deterministic : bool, default = False
            When True, sets ``CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=-1`` in the
            process environment (the workspace-optimization path is deterministic).
        """
        super().__init__()

        self.logger = logging.getLogger("FusedAttention")
        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # Opt-in flash-attn v2 backward, only honoured on compute capability 9.0.
        self.use_FAv2_bwd = os.getenv(
            "NVTE_FUSED_ATTN_USE_FAv2_BWD", "0"
        ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        if deterministic:
            # workspace optimization path is deterministic
            os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
        # - unset:       enables workspace optimization when required workspace is <= 256MB
        #                or when bias gradient needs to be computed
        # - n:           enables workspace optimization when required workspace is <= n bytes
        # - -1:          enables workspace optimization always
        # - 0:           disables workspace optimization always
        # NOTE: this mutates the process-wide environment and so affects every instance.
        if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            or an unexpected key when loading TransformerEngine checkpoints.
            Please store FP8 metadata as DotProductAttention's _extra_state,
            rather than FusedAttention's _extra_state. This hook will be
            phased out in TransformerEngine 2.0.
            """
            # Iterate over snapshots: calling ``remove`` on the list being iterated
            # would skip the element right after every removed key.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
            for key in list(incompatible_keys.unexpected_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.unexpected_keys.remove(key)
                    warnings.warn(
                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        seq_offsets_q: Optional[torch.Tensor] = None,
        seq_offsets_k: Optional[torch.Tensor] = None,
        seq_offsets_v: Optional[torch.Tensor] = None,
        seq_offsets_o: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """fused attention fprop

        Validates the inputs, derives any missing ``cu_seqlens``/``seq_offsets``,
        and dispatches either to the context-parallel path
        (``attn_forward_func_with_cp``) or to ``FusedAttnFunc``.

        For ``sbhd``/``bshd`` formats the batch size and ``max_seqlen_q/kv`` are
        taken from the tensor shapes (overriding the arguments). For ``thd`` the
        caller must supply ``max_seqlen_q/kv`` and ``cu_seqlens_q/kv``.

        Returns
        -------
        torch.Tensor
            The attention output with its last two dimensions merged
            (``...hd -> ...(hd)``).

        Raises
        ------
        RuntimeError
            If a padding mask is requested without ``attention_mask`` or cu_seqlens.
        AssertionError
            On unsupported backend, dtype, device, layout or parallelism combinations.
        """
        assert (
            fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
        ), "No fused attention backend supports this input combination!"
        # uint8 is admitted as well; presumably these are FP8 payloads for the fp8 path below.
        assert (
            (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
        ), "FusedAttention only supports FP16 and BF16 data types."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
        ), "FusedAttention only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # Keep only the letters of the first layout component,
        # e.g. "sb3hd" -> "sbhd", "bshd_bs2hd" -> "bshd", "thd_thd_thd" -> "thd".
        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

        if qkv_format in ["sbhd", "bshd"]:
            if qkv_format == "sbhd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1],
                    query_layer.shape[0],
                    key_layer.shape[0],
                )
            if qkv_format == "bshd":
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0],
                    query_layer.shape[1],
                    key_layer.shape[1],
                )
            if "padding" in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    # Self-attention shares one mask; cross-attention passes a (q, kv) pair.
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                # Without padding every sequence spans the full max_seqlen.
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        if qkv_format == "thd":
            assert (
                max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
            ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
            # Derive seq offsets when any is missing. Context parallelism always recomputes
            # them, treating q/k/v as separate (hd_hd_hd) tensors because they are made
            # contiguous below.
            if (
                seq_offsets_q is None
                or seq_offsets_k is None
                or seq_offsets_v is None
                or seq_offsets_o is None
                or context_parallel
            ):
                qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
                qkv_group = "hd_hd_hd" if context_parallel else qkv_group
                num_heads = query_layer.shape[-2]
                num_gqa_groups = key_layer.shape[-2]
                head_dim = query_layer.shape[-1]
                # Offsets = cu_seqlens scaled by the per-token element stride of each layout.
                seq_offsets_o = num_heads * head_dim * cu_seqlens_q
                if qkv_group == "hd_hd_hd":
                    seq_offsets_q = num_heads * head_dim * cu_seqlens_q
                    seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
                    seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
                if qkv_group in ["3hd", "h3d"]:
                    seq_offsets_q = num_heads * head_dim * 3 * cu_seqlens_q
                    seq_offsets_k = num_heads * head_dim * 3 * cu_seqlens_q
                    seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q
                if qkv_group in ["hd_2hd", "hd_h2d"]:
                    seq_offsets_q = num_heads * head_dim * cu_seqlens_q
                    seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
                    seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # The flash-attn backward is only used with no bias on the arbitrary-seqlen backend.
        use_FAv2_bwd = (
            self.use_FAv2_bwd
            and (core_attention_bias_type == "no_bias")
            and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
        )

        if context_parallel:
            assert (
                fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
            ), f"{fused_attention_backend} does not work with context parallelism!"
            assert core_attention_bias_type not in [
                "alibi"
            ], f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in (query_layer, key_layer, value_layer)
            ]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    seq_offsets_q,
                    seq_offsets_k,
                    seq_offsets_v,
                    seq_offsets_o,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    use_fused_attention=True,
                )
        else:
            with self.attention_dropout_ctx():
                if fp8:
                    assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                        f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                        " is required for FP8 attention!"
                    )
                    assert (
                        fp8_meta is not None
                    ), "FP8 metadata fp8_meta is required for FP8 attention!"
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q,
                    max_seqlen_kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    seq_offsets_q,
                    seq_offsets_k,
                    seq_offsets_v,
                    seq_offsets_o,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    self.softmax_scale,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    None,  # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                    fp8,
                    fp8_meta,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
4383
4384


4385
class DotProductAttention(TransformerEngineBaseModule):
4386
4387
4388
4389
4390
4391
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

4392
        Argument :attr:`attention_mask` in the `forward` call is only used when
4393
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
4394
4395
4396

    .. warning::

4397
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
4398
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
4399
4400
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
4401
4402
4403
4404
4405
4406

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels : int
4407
                number of key-query-value channels per attention head.
4408
4409
4410
4411
4412
4413
4414
4415
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
4416
4417
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
4418
    attn_mask_type: str, default = `causal`
4419
4420
4421
4422
4423
4424
4425
4426
4427
4428
4429
4430
4431
4432
                   type of attention mask passed into softmax operation, options are "`no_mask`",
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`", and
                   "`arbitrary`", where "`padding,causal`" and "`causal,padding`" are equivalent.
                   This arg can be overridden by :attr:`attn_mask_type` in the `forward` method.
                   It is useful for cases involving compilation/tracing, e.g. ONNX export, and the
                   forward arg is useful for dynamically changing mask types, e.g. a different mask
                   for training and inference. For "`no_mask`", no attention mask is applied. For
                   "`causal`" or the causal mask in "`padding,causal`", TransformerEngine calculates
                   and applies an upper triangular mask to the softmax input. No user input is
                   needed. For "`padding`" or the padding mask in "`padding,causal`", users need to
                   provide the locations of padded tokens either via :attr:`cu_seqlens_q` and
                   :attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
                   in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
                   need to provide a mask that is broadcastable to the shape of softmax input.
4433
4434
4435
4436
4437
4438
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
4439
4440
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
4441
4442
4443
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
4444
4445
4446
4447
4448
4449
4450
4451
4452
4453
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
               `h` the number of heads, `d` head size, and `t` the total number of sequences
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
               For that, please use `_get_qkv_layout` to gain the layout information.
4454
4455
4456
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
                `1.0 / math.sqrt(kv_channels)`.
4457
4458
4459
4460
4461
4462
4463
4464
4465

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
4466
4467
4468
4469
4470
4471
4472
4473
4474
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
4475
4476
4477
4478
4479
4480
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Build the attention module and instantiate each eligible backend.

        Normalizes ``attn_mask_type`` ("causal,padding" -> "padding_causal"),
        resolves the tensor-parallel size from ``tp_group`` when one is given,
        picks the dropout RNG context, and creates the flash-attn / fused /
        unfused attention sub-modules according to environment variables and
        device compute capability.

        Raises
        ------
        AssertionError
            If ``num_attention_heads`` is not divisible by the number of GQA groups,
            or if ``attention_type`` is not in ``AttnTypes``.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.qkv_format = qkv_format
        # Accept both "a,b" and "a_b" spellings; "causal_padding" is an alias.
        attn_mask_type = attn_mask_type.replace(",", "_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = window_size
        self.window_size = check_set_window_size(attn_mask_type, self.window_size)
        # A provided tp_group takes precedence over the tp_size argument.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

        self.hidden_size_per_attention_head = kv_channels

        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        # Use the resolved TP world size (self.tp_size), not the raw argument, so the
        # per-partition count is right when only tp_group is supplied.
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"

        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(kv_channels)

        self.device_compute_capability = get_device_compute_capability()
        self.deterministic = (
            not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
            or torch.are_deterministic_algorithms_enabled()
        )

        self.use_flash_attention = int(
            os.getenv("NVTE_FLASH_ATTN", "1")
        ) and self.device_compute_capability >= (8, 0)
        if int(os.getenv("NVTE_FLASH_ATTN", "1")) == 0:
            self.logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
        if self.device_compute_capability < (8, 0):
            self.logger.debug("Disabling FlashAttention for compute capability < sm80")

        if not _flash_attn_2_4_1_plus and self.deterministic:
            self.use_flash_attention = False
            self.logger.warning(
                "Disabling usage of FlashAttention since version <2.4.1 does not support "
                "deterministic execution. In order to use FA with deterministic behavior,"
                " please install FlashAttention version >=2.4.1."
            )

        self.use_fused_attention = int(
            os.getenv("NVTE_FUSED_ATTN", "1")
        ) and self.device_compute_capability >= (8, 0)
        if int(os.getenv("NVTE_FUSED_ATTN", "1")) == 0:
            self.logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
        if self.device_compute_capability < (8, 0):
            self.logger.debug("Disabling FusedAttention for compute capability < sm80")

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        if self.use_flash_attention:
            self.flash_attention = FlashAttention(
                softmax_scale,
                attention_type=attention_type,
                layer_number=layer_number,
                deterministic=self.deterministic,
                **attn_kwargs,
            )

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        if self.use_fused_attention:
            self.fused_attention = FusedAttention(
                softmax_scale,
                attention_type=attention_type,
                layer_number=layer_number,
                deterministic=self.deterministic,
                **attn_kwargs,
            )

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale, **attn_kwargs, layer_number=layer_number
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove core_attention._extra_state as a missing key
            when loading older TransformerEngine checkpoints. Will phase out
            this hook in TransformerEngine 2.0.
            """
            # Iterate over a snapshot: removing from the list being iterated would
            # skip the key immediately following each removed one.
            for key in list(incompatible_keys.missing_keys):
                if "core_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

4612
4613
4614
4615
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Forward method with activation checkpointing.

        Runs ``attention_func(*forward_args, **forward_kwargs)`` through
        ``checkpoint`` using this module's RNG-state tracker and tensor-parallel
        group, without distributing saved activations, and returns its result.
        """

        def _invoke_attention(*call_args, **call_kwargs):
            # Pass-through wrapper handed to ``checkpoint``; forwards all arguments unchanged.
            return attention_func(*call_args, **call_kwargs)

        return checkpoint(
            _invoke_attention,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

4634
4635
4636
4637
4638
4639
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Configure context parallelism for this module; call before the forward pass.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # Assigned in order: group, ranks, stream.
        self.cp_group, self.cp_global_ranks, self.cp_stream = (
            cp_group,
            cp_global_ranks,
            cp_stream,
        )

4657
    @no_torch_dynamo(recursive=False)
4658
4659
4660
4661
4662
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
4663
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
4664
4665
4666
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
4667
4668
4669
4670
        seq_offsets_q: Optional[torch.Tensor] = None,
        seq_offsets_k: Optional[torch.Tensor] = None,
        seq_offsets_v: Optional[torch.Tensor] = None,
        seq_offsets_o: Optional[torch.Tensor] = None,
4671
4672
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
4673
        attn_mask_type: Optional[str] = None,
4674
        window_size: Optional[Tuple[int, int]] = None,
4675
        checkpoint_core_attention: bool = False,
4676
4677
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
4678
        alibi_slopes: Optional[torch.Tensor] = None,
4679
        fast_zero_fill: bool = True,
4680
        inference_params: Optional[InferenceParams] = None,
4681
        is_first_microbatch: Optional[bool] = None,
4682
4683
4684
4685
4686
4687
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

4688
4689
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
4690
4691
4692

        .. note::

4693
4694
4695
            Input tensor :attr:`query_layer` must be of shape
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
            :attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer`
4696
            must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
4697
            :attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape
4698
4699
4700
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
            * :attr:`kv_channels`) is returned.

4701
4702
        .. note::

4703
4704
4705
4706
4707
4708
4709
4710
4711
4712
4713
4714
4715
4716
4717
4718
4719
4720
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://arxiv.org/pdf/2305.13245.pdf>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
4721
4722
4723
4724
4725
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
4726

4727
4728
4729
4730
4731
4732
4733
4734
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
4735
4736
4737
4738
4739
4740
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
4741
4742
4743
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value
             means the corresponding position is masked out and a `False` means that position is
             allowed to participate in attention.
4744
4745
4746
4747
4748
4749
4750
4751
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
4752
4753
4754
4755
4756
4757
4758
4759
4760
4761
4762
4763
        seq_offsets_q: Optional[torch.Tensor], default = `None`
                   Cumulative offset of different sequences in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
        seq_offsets_k: Optional[torch.Tensor], default = `None`
                   Cumulative offset of different sequences in a batch for `key_layer`,
                   with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
        seq_offsets_v: Optional[torch.Tensor], default = `None`
                   Cumulative offset of different sequences in a batch for `value_layer`,
                   with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
        seq_offsets_o: Optional[torch.Tensor], default = `None`
                   Cumulative offset of different sequences in a batch for forward output,
                   with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
4764
4765
4766
4767
4768
4769
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
4770
4771
4772
        attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
                       `arbitrary`}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
4773
        window_size: Optional[Tuple[int, int]], default = `None`
4774
                    Sliding window size for local attention.
4775
4776
4777
4778
4779
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
4780
        core_attention_bias_type: str, default = `no_bias`
4781
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
4782
        core_attention_bias: Optional[torch.Tensor], default = `None`
4783
4784
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
4785
4786
4787
4788
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
4789
        fast_zero_fill: bool, default = `True`
4790
                    Whether to use the fast path to set output tensors to 0 or not.
4791
4792
4793
4794
4795
4796
4797
4798
4799
4800
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
4801
4802
4803
4804
4805
4806
4807
4808
4809
4810
4811
4812
4813
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
4814
        """
4815
4816
4817
4818
4819
4820
4821
4822
4823
4824
4825
4826
4827
4828
4829
4830
4831
4832
4833
4834
4835
4836
4837
4838
        with self.prepare_forward(
            query_layer,
            is_first_microbatch,
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:

            if self.fp8:
                forced_fp8_dpa = ""
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
                        self.fp8_meta["recipe"].fp8_dpa = True
                        forced_fp8_dpa = " (forced)"

            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
                assert forward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ] and backward_dtype in [
                    tex.DType.kFloat8E4M3,
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
4839

4840
4841
4842
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
4843

4844
            assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
4845

4846
4847
4848
4849
4850
4851
4852
4853
            if attn_mask_type is not None:
                window_size = check_set_window_size(attn_mask_type, window_size)
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
4854

4855
            assert (
4856
4857
4858
4859
4860
4861
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
4862

4863
4864
4865
4866
4867
4868
4869
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
4870

4871
4872
            if window_size is None:
                window_size = self.window_size
4873

4874
4875
            if qkv_format is None:
                qkv_format = self.qkv_format
4876

4877
4878
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
4879

4880
4881
4882
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
4883

4884
4885
4886
4887
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]
4888

4889
4890
4891
                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)
4892

4893
4894
4895
                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)
4896

4897
4898
4899
4900
4901
4902
4903
4904
4905
                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
4906

4907
4908
4909
                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)
4910

4911
4912
                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()
4913
4914

            assert (
4915
4916
4917
4918
4919
4920
4921
4922
4923
4924
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
4925
                assert all(
4926
4927
4928
4929
4930
4931
4932
4933
4934
4935
4936
4937
4938
4939
4940
4941
4942
4943
4944
4945
4946
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
                if max_seqlen_q is None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
                if max_seqlen_kv is None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))

            if qkv_format in ["sbhd", "bshd"]:
4947
                assert all(
4948
4949
4950
4951
4952
4953
4954
4955
4956
4957
4958
4959
4960
4961
4962
4963
4964
4965
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
                if qkv_format == "bshd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                        the sequence dimention in 'query_layer'!"""
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                        the sequence dimention in 'key_layer' and 'value_layer'!"""
4966

4967
4968
4969
4970
4971
4972
4973
4974
4975
4976
4977
4978
            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout(
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
                qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )
4979

4980
4981
4982
4983
4984
            # The priority for attention backends (subject to availability and clearing the filters)
            # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
            use_flash_attention = self.use_flash_attention
            use_fused_attention = self.use_fused_attention
            use_unfused_attention = True
4985

4986
4987
            # The following section filters out some backends based on
            # certain asserts before executing the forward pass.
4988

4989
4990
4991
4992
            # Filter: QKV layout.
            if use_unfused_attention and qkv_format == "thd":
                self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
                use_unfused_attention = False
4993

4994
4995
            # Filter: ONNX export.
            if is_in_onnx_export_mode():
4996
                if use_flash_attention:
4997
                    self.logger.debug("Disabling FlashAttention for ONNX mode")
4998
                use_flash_attention = False
4999
5000
5001
                if use_fused_attention:
                    self.logger.debug("Disabling FusedAttention for ONNX mode")
                use_fused_attention = False
5002

5003
5004
5005
5006
5007
5008
5009
5010
5011
5012
5013
5014
5015
5016
5017
5018
5019
5020
5021
5022
5023
5024
5025
5026
5027
5028
5029
5030
5031
5032
            # Filter: Input type.
            if use_flash_attention and (
                query_layer.dtype not in [torch.bfloat16, torch.float16]
                or key_layer.dtype not in [torch.bfloat16, torch.float16]
                or value_layer.dtype not in [torch.bfloat16, torch.float16]
                or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
            ):
                self.logger.debug(
                    "Disabling FlashAttention due to unsupported QKV data types. "
                    "Supported: [torch.bfloat16, torch.float16]. "
                    "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
                    query_layer.dtype,
                    key_layer.dtype,
                    value_layer.dtype,
                )
                use_flash_attention = False
            if use_fused_attention and (
                query_layer.dtype not in [torch.bfloat16, torch.float16]
                or key_layer.dtype not in [torch.bfloat16, torch.float16]
                or value_layer.dtype not in [torch.bfloat16, torch.float16]
            ):
                self.logger.debug(
                    "Disabling FusedAttention due to unsupported QKV data types. "
                    "Supported: [torch.bfloat16, torch.float16, Float8Tensor]. "
                    "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
                    query_layer.dtype,
                    key_layer.dtype,
                    value_layer.dtype,
                )
                use_fused_attention = False
5033

5034
5035
5036
5037
5038
5039
5040
5041
5042
5043
5044
5045
5046
5047
5048
5049
5050
5051
5052
5053
5054
5055
5056
5057
5058
5059
5060
5061
5062
5063
            # Filter: Execution type.
            if use_flash_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                self.logger.debug("Disabling FlashAttention as it does not support FP8 execution.")
                use_flash_attention = False
            if use_unfused_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                self.logger.debug(
                    "Disabling UnfusedDotProductAttention as it does not support FP8 execution."
                )
                use_unfused_attention = False

            # Filter: Device and dimensions.
            # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
            # FAv2 requires head_dim % 8 == 0
            if use_flash_attention and (
                query_layer.shape[-1] > 256
                or query_layer.shape[-1] % 8 != 0
                or (
                    query_layer.shape[-1] > 192
                    and self.device_compute_capability not in ((8, 0), (9, 0))
                )
            ):
                self.logger.debug(
                    "Disabling FlashAttention due to unsupported head_dim. "
                    "Supported: %%8 == 0, and <= 256; sm80/90 for >192. "
                    "Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s",
                    query_layer.shape[-1],
                    key_layer.shape[-1],
                    ".".join([str(i) for i in self.device_compute_capability]),
                )
                use_flash_attention = False
5064

5065
5066
            # Filter: cross attention + causal mask.
            # (in training mode)
5067
            if (
5068
5069
5070
5071
5072
                use_flash_attention
                and inference_params is None
                and _flash_attn_2_1_plus
                and "causal" in attn_mask_type
                and max_seqlen_q != max_seqlen_kv
5073
            ):
5074
5075
5076
5077
5078
5079
                self.logger.warning(
                    "In training mode, disable the use of FlashAttention since version 2.1+ has "
                    "changed its behavior for causal mask in cross attention. See "
                    "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
                )
                use_flash_attention = False
5080

5081
5082
            context_parallel = (
                self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
5083
            )
5084
5085
5086
5087
5088
5089
5090
5091
5092
5093
5094
5095
5096
5097
5098
5099
5100
5101
5102
5103
5104
5105
5106
5107
5108
5109
5110
5111
5112
5113

            # Filter: sliding window attention.
            # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
            if window_size not in ((-1, -1), (-1, 0)):
                if use_fused_attention:
                    self.logger.debug("Disabling FusedAttention for SWA")
                use_fused_attention = False
                if (not _flash_attn_2_3_plus) or context_parallel:
                    if use_flash_attention:
                        self.logger.debug(
                            "Disabling FusedAttention as it requires flash-attn 2.3+ "
                            "and no context parallelism"
                        )
                    use_flash_attention = False

            # Filter: Attention mask type.
            #   attn_mask_type(s)    |     supported backends
            # ------------------------------------------------
            #   no_mask              |     All
            #   padding              |     UnfusedDotProductAttention, FlashAttention, FusedAttention
            #   causal               |     All
            #   padding + causal     |     FlashAttention, FusedAttention
            #   arbitrary            |     UnfusedDotProductAttention
            #
            if attn_mask_type == "arbitrary":
                if use_flash_attention:
                    self.logger.debug("Disabling FlashAttention for arbitrary mask")
                use_flash_attention = False
                if use_fused_attention:
                    self.logger.debug("Disabling FusedAttention for arbitrary mask")
5114
5115
                use_fused_attention = False

5116
5117
5118
5119
5120
5121
5122
5123
5124
5125
5126
5127
5128
5129
5130
5131
5132
5133
5134
5135
5136
5137
5138
5139
5140
5141
5142
5143
5144
5145
5146
5147
5148
5149
5150
5151
5152
5153
5154
5155
5156
5157
5158
5159
5160
5161
5162
5163
5164
5165
5166
5167
            if (
                use_unfused_attention
                and inference_params is None
                and "causal" in attn_mask_type
                and max_seqlen_q != max_seqlen_kv
            ):
                self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
                use_unfused_attention = False

            # Filter: bias.
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

            if use_flash_attention and (
                core_attention_bias_type not in ["no_bias", "alibi"]
                or core_attention_bias is not None
            ):
                self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
                use_flash_attention = False

            fu_core_attention_bias_type = core_attention_bias_type
            fu_core_attention_bias = core_attention_bias
            if (
                core_attention_bias_type == "alibi"
                and use_fused_attention
                and alibi_slopes is not None
            ):
                fu_core_attention_bias_type = "post_scale_bias"
                _, fu_core_attention_bias = get_alibi(
                    query_layer.shape[-2],
                    max_seqlen_q,
                    max_seqlen_kv,
                    alibi_slopes=alibi_slopes,
                    bias_dtype=query_layer.dtype,
5168
5169
                )
            if (
5170
                use_fused_attention
5171
                and fu_core_attention_bias_type == "post_scale_bias"
5172
5173
5174
5175
5176
                and (
                    fu_core_attention_bias.shape[0] != 1
                    or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
                )
            ):
5177
5178
5179
5180
5181
5182
5183
5184
5185
5186
5187
5188
5189
5190
5191
5192
5193
5194
5195
5196
5197
5198
5199
5200
5201
5202
5203
5204
5205
5206
5207
5208
5209
5210
5211
5212
5213
5214
5215
5216
5217
5218
5219
5220
5221
5222
5223
                if fu_core_attention_bias.requires_grad:
                    # remove this line when cuDNN adds bwd support for
                    # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
                    self.logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
                    use_fused_attention = False
                else:
                    # max512 backend will only support [1, h, s, s]
                    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

            if use_fused_attention:
                q_type = TE_DType[query_layer.dtype]
                kv_type = TE_DType[key_layer.dtype]
                if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                    if isinstance(query_layer, Float8Tensor) and isinstance(
                        key_layer, Float8Tensor
                    ):
                        q_type = query_layer._fp8_dtype
                        kv_type = value_layer._fp8_dtype
                    else:
                        q_type = forward_dtype
                        kv_type = forward_dtype
                fused_attention_backend = tex.get_fused_attn_backend(
                    q_type,
                    kv_type,
                    QKVLayout[qkv_layout],
                    AttnBiasType[fu_core_attention_bias_type],
                    AttnMaskType[attn_mask_type],
                    self.attention_dropout,
                    query_layer.shape[-2],  # num_attn_heads
                    key_layer.shape[-2],  # num_gqa_groups
                    max_seqlen_q,
                    max_seqlen_kv,
                    query_layer.shape[-1],  # head_dim
                )
                # DPA does not support FP8; for FP8, use cpp_extensions modules directly
                is_backend_avail = fused_attention_backend in [
                    FusedAttnBackend["F16_max512_seqlen"],
                    FusedAttnBackend["F16_arbitrary_seqlen"],
                    FusedAttnBackend["FP8"],
                ]
                use_fused_attention = (
                    use_fused_attention
                    and is_backend_avail
                    and (
                        not context_parallel
                        or fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
                    )
5224
                )
5225
5226
5227
5228
5229
5230
5231
5232
5233
5234
5235
5236
5237
5238
5239
5240
5241
5242
5243
5244
5245
5246
5247
5248
5249
5250
5251
5252
5253
5254
5255
5256
                if (
                    fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
                    and fu_core_attention_bias_type == "post_scale_bias"
                    and (
                        fu_core_attention_bias.shape[0] != 1
                        or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
                    )
                ):
                    self.logger.debug(
                        "Disabling FusedAttention as no backend supports the provided input"
                    )
                    use_fused_attention = False

            # Filter: determinism.
            # backend                                  | deterministic
            # ---------------------------------------------------------
            # flash-attn v1                            | yes
            # flash-attn v2                            | no
            # FusedAttnBackend["F16_max512_seqlen"]    | yes
            # FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
            # UnfusedDotProductAttention               | yes
            #
            # Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
            # on sm90 architectures.
            #
            if (
                use_fused_attention
                and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
                and self.deterministic
                and self.device_compute_capability != (9, 0)
            ):
                self.logger.debug("Disabling FusedAttention for determinism reasons")
5257
                use_fused_attention = False
5258

5259
5260
5261
5262
5263
5264
5265
5266
5267
5268
5269
5270
            # Select FusedAttention on sm90 and FlashAttention on others for performance
            if (
                use_flash_attention
                and use_fused_attention
                and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            ):
                if self.device_compute_capability == (9, 0):
                    self.logger.debug(
                        "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                        "for performance reasons"
                    )
                    use_flash_attention = False
5271

5272
5273
5274
5275
5276
5277
5278
5279
5280
5281
5282
5283
5284
5285
5286
5287
5288
5289
5290
5291
5292
5293
5294
5295
5296
5297
5298
            run_config = {
                "compute_capability": "sm"
                + str(
                    (lambda x, y: x * 10 + y)(
                        self.device_compute_capability[0], self.device_compute_capability[1]
                    )
                ),
                "q_dtype": query_layer.dtype,
                "k_dtype": key_layer.dtype,
                "v_dtype": value_layer.dtype,
                "q_shape": list(query_layer.shape),
                "k_shape": list(key_layer.shape),
                "v_shape": list(value_layer.shape),
                "qkv_format": qkv_format,
                "qkv_layout": qkv_layout,
                "mask_type": attn_mask_type,
                "bias_type": core_attention_bias_type,
                "bias_shape": (
                    core_attention_bias.shape if core_attention_bias is not None else None
                ),
                "dropout": self.attention_dropout,
                "context_parallel": context_parallel,
                "is_training": self.training,
                "transformer_engine_version": te.__version__,
                "flash_attn_version": _flash_attn_version,
                "cudnn_version": ".".join([str(i) for i in get_cudnn_version()]),
            }
5299

5300
5301
5302
5303
5304
5305
5306
5307
5308
5309
5310
5311
5312
5313
5314
5315
5316
5317
5318
5319
5320
5321
5322
5323
5324
5325
            if use_flash_attention:
                self.logger.info("Running with FlashAttention backend ")
                self.logger.debug("Running with config=%s", run_config)
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
5326
                )
5327

5328
5329
5330
5331
            if use_fused_attention:
                self.logger.info(
                    "Running with FusedAttention backend (sub-backend %s)",
                    int(fused_attention_backend),
5332
                )
5333
5334
5335
5336
5337
5338
5339
5340
5341
5342
5343
5344
5345
5346
5347
5348
5349
5350
5351
5352
5353
5354
5355
5356
5357
5358
5359
5360
5361
5362
5363
5364
5365
5366
5367
5368
5369
5370
                if self.fp8:
                    self.logger.debug(
                        "Running with fp8_recipe.fp8_mha=%s, "
                        "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
                        self.fp8_meta["recipe"].fp8_mha,
                        self.fp8_meta["recipe"].fp8_dpa,
                        forced_fp8_dpa,
                        int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
                    )
                self.logger.debug("Running with config=%s", run_config)
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        seq_offsets_q=seq_offsets_q,
                        seq_offsets_k=seq_offsets_k,
                        seq_offsets_v=seq_offsets_v,
                        seq_offsets_o=seq_offsets_o,
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
5371
5372
5373
5374
5375
5376
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
5377
5378
5379
5380
                    seq_offsets_q=seq_offsets_q,
                    seq_offsets_k=seq_offsets_k,
                    seq_offsets_v=seq_offsets_v,
                    seq_offsets_o=seq_offsets_o,
5381
5382
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
5383
5384
5385
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    fused_attention_backend=fused_attention_backend,
5386
5387
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
5388
5389
5390
5391
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
5392
5393
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
5394
                )
5395

5396
5397
5398
            assert (
                not context_parallel
            ), "Context parallelism is only implemented with Flash Attention and Fused Attention!"
5399

5400
            from .cpu_offload import CPUOffloadEnabled
5401

5402
5403
5404
5405
5406
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented"
                    "with Flash Attention and Fused Attention!"
                )
5407

5408
5409
5410
5411
5412
5413
5414
5415
5416
5417
5418
5419
5420
5421
5422
5423
5424
5425
5426
            if use_unfused_attention:
                self.logger.info("Running with UnfusedDotProductAttention backend")
                self.logger.debug("Running with config=%s", run_config)
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
5427
5428
5429
                    query_layer,
                    key_layer,
                    value_layer,
5430
5431
5432
5433
5434
5435
5436
5437
5438
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )
5439

5440
            raise Exception("No dot product attention support for the provided inputs!")
5441
5442


5443
5444
5445
5446
5447
5448
5449
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing the QKV projection weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing the output projection (PROJ) weights in the
                              following way: `output_layer_init_method(weight)`. When set to
                              `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                   default = `causal`
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
            For that, please use `_get_qkv_layout` to gain the layout information.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """
        Build the input projection(s), the core `DotProductAttention` and the
        output projection. See the class docstring for parameter semantics.

        Submodules created:

        * ``attention_type == "self"``: ``layernorm_qkv`` (when
          `input_layernorm`) or ``qkv``, projecting to Q, K and V at once.
        * ``attention_type == "cross"``: ``layernorm_query`` (when
          `input_layernorm`) or ``query_layer`` for Q, plus ``key_value``.
        * ``core_attention`` and ``proj`` in both cases.

        Raises
        ------
        AssertionError
            if `attention_type` is not in `AttnTypes`, if `layer_number` is not
            positive, or if the head / GQA-group counts are not divisible as
            required.
        """
        super().__init__()

        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        # Reconcile the requested sliding window with the default mask type.
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        # Per-head size; a falsy value (None/0) falls back to an even split of hidden_size.
        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaving only applies to a single fused QKV weight.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # An explicit TP group takes precedence over the user-supplied tp_size.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)

        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)

        self.hidden_size_per_attention_head = kv_channels
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

        # Arguments shared by every GEMM-backed submodule below.
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # Unfused params: expose separate query/key/value weights via parameters_split.
            parameters_split = None
            if not fuse_qkv_params:
                parameters_split = collections.OrderedDict(
                    [
                        ("query", self.hidden_size_q),
                        ("key", self.hidden_size_kv),
                        ("value", self.hidden_size_kv),
                    ]
                )
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    self.hidden_size_q,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            # K and V come from the encoder output and share one projection.
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            self.hidden_size_per_attention_head,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Linear
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Allocate an uninitialized buffer on the current CUDA device.

        The buffer has shape ``[inference_max_sequence_len, batch_size,
        num_gqa_groups_per_partition, hidden_size_per_attention_head]``. It is
        sized by the per-partition KV head count, so it is presumably used as
        the inference key/value cache (NOTE(review): confirm against callers).

        Parameters
        ----------
        inference_max_sequence_len : int
            size of the first (sequence) dimension.
        batch_size : int
            size of the second (batch) dimension.
        dtype : torch.dtype
            element type of the buffer.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is stored on this module and also forwarded to every
        descendant module that defines ``set_tensor_parallel_group`` (e.g. the
        QKV / output projections). This mirrors `set_context_parallel_group`, so
        a module built with ``tp_group=None`` can be given its group afterwards.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        # self.modules() yields self first; skip it to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

5813
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        This module keeps no context-parallel state of its own. The call is
        forwarded to every descendant module that defines
        ``set_context_parallel_group``.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # Deep iterate but skip self (yielded first by self.modules()) to avoid
        # infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
5838

5839
5840
5841
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
    ) -> Union[torch.Tensor, Tuple[Union[torch.Tensor, None], ...]]:
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value
             means the corresponding position is masked out and a `False` means that position is
             allowed to participate in attention.
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                       default = `None`
                       type of attention mask passed into softmax operation. `None` falls back
                       to the mask type the module was configured with.
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention. When `attn_mask_type` is passed
                    explicitly, the window size is checked against it first; if the result is
                    still `None`, the module's configured window size is used.
        encoder_output : Optional[torch.Tensor], default = `None`
             Input to the key/value projection when the module performs cross-attention
             (`attention_type="cross"`); unused for self-attention.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        inference_params: Optional[InferenceParams], default = `None`
                         Inference state. When given and the layer has a `layer_number`,
                         key/value buffers are allocated for this layer on first use. It is also
                         used to slice the rotary embeddings and is forwarded to core attention.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.

        Returns
        -------
        The attention output. If `return_bias` is set, the projection bias follows it.
        If the layer has an input layernorm and `return_layernorm_output` is set, the
        layernorm output comes last. When only the attention output is produced it is
        returned as a bare tensor; otherwise a tuple is returned.

        Raises
        ------
        AssertionError
            If a padding mask is not boolean, the bias type is unsupported, or RoPE is
            requested on Float8Tensor query/key.
        ValueError
            If RoPE is combined with `inference_params` for a `qkv_format` other than
            "sbhd" or "bshd".
        """
        # hidden_states: [sq, b, h]

        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)
        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        if window_size is None:
            window_size = self.window_size

        if "padding" in attn_mask_type and attention_mask is not None:
            # A single tensor (self-attention) or a (q, kv) pair (cross-attention) may be
            # passed; check each mask once rather than iterating a tensor's batch dimension.
            masks = (
                attention_mask if isinstance(attention_mask, (tuple, list)) else (attention_mask,)
            )
            for mask in masks:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"

        assert (
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
        # Pre-allocate memory for key-values for inference
        # =================================================

        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                # (key buffer, value buffer), allocated once per layer.
                inference_params.key_value_memory_dict[self.layer_number] = (
                    self._allocate_memory(inf_max_seq_len, inf_max_batch_size, hidden_states.dtype),
                    self._allocate_memory(inf_max_seq_len, inf_max_batch_size, hidden_states.dtype),
                )

        # ======================
        # Query, Key, and Value
        # ======================

        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    is_first_module_in_mha=True,  # specific to FP8 MHA
                )

            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )
            if self.qkv_weight_interleaved:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along third last dimension
                split_dim = -3

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
                )
            else:
                query_layer, key_layer, value_layer = torch.split(
                    mixed_x_layer,
                    (num_queries_per_key_value, 1, 1),
                    dim=split_dim,
                )

            # query: -> [sq, b, np, hn]
            # key, value: -> [sq, b, ng, hn]
            query_layer, key_layer, value_layer = (
                x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
                for x in (query_layer, key_layer, value_layer)
            )

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
                is_first_module_in_mha=True,  # specific to FP8 MHA
            )

            if self.qkv_weight_interleaved:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    2 * self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_kv_layer,
                    split_dim,
                    mixed_kv_layer.shape[split_dim] // 2,
                )
            else:
                key_layer, value_layer = torch.split(
                    mixed_kv_layer,
                    mixed_kv_layer.shape[split_dim] // 2,
                    dim=split_dim,
                )

            key_layer, value_layer = (
                x.reshape(
                    x.size(0),
                    x.size(1),
                    -1,
                    self.hidden_size_per_attention_head,
                )
                for x in (key_layer, value_layer)
            )

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    is_first_module_in_mha=True,  # specific to FP8 MHA
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================

        if rotary_pos_emb is not None:
            assert not isinstance(query_layer, Float8Tensor) and not isinstance(
                key_layer, Float8Tensor
            ), "RoPE is not supported for Float8Tensors!"
            # duplicate the pos_emb for self attention
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = (rotary_pos_emb,) * 2

            q_pos_emb, k_pos_emb = rotary_pos_emb

            # adjust key and value for inference: slice the embeddings to the positions
            # covered by this step, starting at the current sequence offset
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)
                else:
                    # Previously this fell through and failed with UnboundLocalError.
                    raise ValueError(
                        f"qkv_format {self.qkv_format} is not supported for RoPE with "
                        "inference_params!"
                    )

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

        # ===========================
        # Core attention computation
        # ===========================

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            qkv_format=self.qkv_format,
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
            window_size=window_size,
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            fast_zero_fill=fast_zero_fill,
            inference_params=inference_params,
        )

        # ===================
        # Output. [sq, b, h]
        # ===================

        projection_output = self.proj(
            context_layer,
            is_first_microbatch=is_first_microbatch,
        )

        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
        # layernorm_output is only bound when both conditions hold (see the QKV section).
        if self.input_layernorm and self.return_layernorm_output:
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]