attention.py 239 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
from contextlib import nullcontext
8
from importlib.metadata import version as get_pkg_version
9
import math
10
import os
11
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
import warnings
13
import logging
14

cyanguwa's avatar
cyanguwa committed
15
import numpy as np
16
from packaging.version import Version as PkgVersion
17
18

import torch
19
import torch.nn.functional as F
20

21
import transformer_engine_torch as tex
22
23
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
24
25
26
27
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
28
29
30
31
32
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
33
34
    fused_attn_fwd,
    fused_attn_bwd,
35
36
37
38
39
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
40
41
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
42
from transformer_engine.pytorch.module import LayerNormLinear, Linear
43
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
44
45
46
47
48
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
49
    get_default_init_method,
50
51
52
53
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
54
    AttnBiasTypes,
55
    QKVLayouts,
56
    dist_group_type,
57
    TE_DType,
58
59
60
61
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
62
    get_distributed_rank,
63
    checkpoint,
64
65
66
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
67
68
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
69
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
70
71
from transformer_engine.pytorch.graph import is_graph_capturing

72

73
74
75
76
77
78
79
# Installed flash-attn version. importlib.metadata.version raises
# PackageNotFoundError when flash-attn is not installed, so importing this
# module requires flash-attn to be present.
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
# The flash-attn interface functions are imported below only when the
# installed version is at least this one.
_flash_attn_version_required = PkgVersion("2.0.6")
# NOTE(review): not referenced in the visible part of this file; presumably an
# upper bound on the supported flash-attn version checked elsewhere -- confirm.
_flash_attn_max_version = PkgVersion("2.5.8")
# Feature flags for newer flash-attn releases. For example, 2.3+ enables the
# `window_size` kwarg and 2.4+ the `alibi_slopes` kwarg this file passes to
# flash-attn.
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
80

81
if _flash_attn_version >= _flash_attn_version_required:
82
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
83
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
84
85
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
86

87
88
89
90
91
92
# FP8 metadata slots used for the attention tensors. They alias TE's generic
# GEMM forward/backward tensor enums (tex.FP8FwdTensors / tex.FP8BwdTensors).
# NOTE(review): by name they correspond to QKV, dQKV, O, dO, S (softmax) and dP;
# confirm the mapping against the FP8 consumers, which are outside this view.
META_QKV  = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O    = tex.FP8FwdTensors.GEMM2_INPUT
META_DO   = tex.FP8BwdTensors.GRAD_INPUT2
META_S    = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP   = tex.FP8BwdTensors.GRAD_INPUT3
93

94
# Debug logging is controlled by two environment variables:
#   NVTE_DEBUG       = 0/1    -> disable/enable debug mode (default 0)
#   NVTE_DEBUG_LEVEL = 0/1/2  -> increasing verbosity when debug mode is on (default 0)
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))

# The effective level is 0 unless debug mode is on.
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}

# Any level outside 0/1/2 falls back to the most verbose setting (DEBUG).
# NOTE(review): basicConfig configures the root logger at import time, and only
# if it has no handlers yet.
logging.basicConfig(
    format='[%(levelname)-8s | %(name)-19s]: %(message)s',
    level=log_levels.get(log_level, logging.DEBUG),
)

105
106
107
108
109
110
111
112
113
# Module-level cache read and written by get_alibi(). It holds the current ALiBi
# slopes and bias, together with the num_heads / max_seqlen_q / max_seqlen_kv they
# were built for. While a `*_require_update` flag is True, the next get_alibi()
# call regenerates the matching entry and resets the flag to False.
# NOTE(review): the flags are presumably raised by callers outside this view
# when the ALiBi configuration changes.
_alibi_cache = {
    "_num_heads": None,
    "_alibi_slopes": None,
    "_max_seqlen_q": None,
    "_max_seqlen_kv": None,
    "_alibi_bias": None,
    "_alibi_slopes_require_update": False,
    "_alibi_bias_require_update": False,
    }
114
115


116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Public API of this module. InferenceParams is defined just below;
# DotProductAttention and MultiheadAttention are defined later in the file.
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]

class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Offsets into the sequence / batch dimensions. Both start at 0;
        # presumably callers outside this class advance them.
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Per-layer KV cache: layer_number -> (key_memory, value_memory).
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorder every layer's cached keys and values along dimension 1, the
        batch dimension of the KV cache, using `batch_indices`.

        Parameters
        ----------
        batch_indices : List[int]
                       New batch order. Its length must equal the batch size
                       of the cached tensors.

        Raises
        ------
        ValueError
            If no layer has a cached KV entry yet.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict in empty")

        memory = self.key_value_memory_dict
        # Values of existing keys are replaced in place; the dict's size never
        # changes, so updating it while iterating is safe.
        for layer, (key_memory, value_memory) in memory.items():
            # The permutation must cover exactly the cached batch size.
            assert len(batch_indices) == key_memory.shape[1]
            memory[layer] = (key_memory[:, batch_indices], value_memory[:, batch_indices])
162

163
164
165
166
167
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the ALiBi slopes and bias stored in the module-level `_alibi_cache`.
    Each is regenerated only while its `_require_update` flag is set, and the
    flag is cleared afterwards.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.

    Returns
    -------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
        then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
        `alibi_slopes` is in [batch_size, num_heads], then the bias is in
        [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].

    Raises
    ------
    ValueError
        If the bias must be rebuilt and the cached slopes are neither 1-D nor 2-D.
    """
    # The cache dict is only mutated, never rebound, so `global` is not needed.
    cache = _alibi_cache

    if cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: m_0**k for k = 1..n, where n is the largest power of
            # two <= num_heads. Any remaining heads get m_hat_0**k for odd k.
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        cache["_num_heads"] = num_heads
        cache["_alibi_slopes_require_update"] = False

    if cache["_alibi_bias_require_update"]:
        slopes = cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        # Reshape the slopes so they broadcast against a [.., .., sq, skv] bias.
        if slopes.dim() == 1:
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            slopes_shape = torch.Size([*slopes.shape, 1, 1])
        else:
            raise ValueError(
                "ALiBi slopes must be in shape [num_heads] or [batch_size, num_heads], "
                f"but got shape {list(slopes.shape)}."
            )
        # bias[..., i, j] = -|(j - (max_seqlen_kv - 1)) - (i - (max_seqlen_q - 1))|:
        # the query/key distance with both sequences aligned at their last position.
        bias = torch.arange(
            1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
        bias = bias - torch.arange(
            1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(1, 1, max_seqlen_q, 1)
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        cache["_max_seqlen_q"], cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        cache["_alibi_bias_require_update"] = False

    return cache["_alibi_slopes"], cache["_alibi_bias"]
231
232
233
234
235
236
237
238
239


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    `True` entries in `mask` mark padded positions, so each sample's length is its
    number of `False` entries. The result lives on the same device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Allocate the leading zero on mask's device (not a hard-coded "cuda") so that
    # torch.cat works for CPU masks and for masks on a non-current GPU.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    return cu_seqlens

247

248
249
250
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns:

    - an int32 tensor of shape [batch_size + 1] containing the cumulative sequence
      lengths of the samples in a batch, and
    - an int64 tensor of shape [batch_size * max_seqlen, 1, 1] containing the flat
      indices (into a [batch_size * max_seqlen] layout) of the valid tokens. The
      tail is padded with the value batch_size * max_seqlen.

    `True` entries in `mask` mark padded positions. Both results live on the same
    device as `mask`.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Use mask's device rather than a hard-coded "cuda" so torch.cat cannot mix devices.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() on the flattened mask gives [num_valid, 1] int64 positions of the
    # valid tokens. Add a trailing dim to get [num_valid, 1, 1].
    indices = mask.reshape(-1).logical_not().nonzero().unsqueeze(-1)

    # Pad to a fixed length of bs * seqlen. The pad value bs * seqlen is one past
    # the last valid flat position.
    num_nonzeros = indices.shape[0]
    pad_amount = bs * seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * seqlen))

    return cu_seqlens, indices


275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the flat indices
    (into a [batch_size * max_seqlen] layout) of the valid tokens in a batch.
    The tail is padded with the value batch_size * max_seqlen. The result lives on
    the same device as `cu_seqlens`.
    """
    bs = len(cu_seqlens) - 1
    # Pull the per-sample lengths to Python ints once, instead of iterating 0-d tensors.
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    # Token `pos` of sample `i` sits at flat position i * max_seqlen + pos.
    flat = [i * max_seqlen + pos for i, length in enumerate(seqlens) for pos in range(length)]
    # Build int64 directly. The legacy torch.Tensor(list) constructor produces
    # float32, which cannot represent every integer above 2**24.
    indices = torch.tensor(flat, dtype=torch.int64, device=cu_seqlens.device).view(-1, 1, 1)

    num_nonzeros = indices.shape[0]
    pad_amount = bs * max_seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * max_seqlen))

    return indices

294
# Memo for _get_full_cu_seqlens, keyed on (batch_size, max_seqlen, device).
_cu_seqlens_cache = {}


def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length, so the result is the
    int32 tensor [0, max_seqlen, 2 * max_seqlen, ..., batch_size * max_seqlen]
    allocated on `device`.

    Results are cached per (batch_size, max_seqlen, device). `device` is part of
    the key so that a request for one device never returns a tensor that was
    created on another.
    """
    key = (batch_size, max_seqlen, device)
    if key not in _cu_seqlens_cache:
        _cu_seqlens_cache[key] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[key]
315
316


317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    Gathers rows of the 3-D `tensor` along dim 0 at the positions in `indices`.
    The index helpers in this file build `indices` as [N, 1, 1]. An index equal
    to `tensor.shape[0]` selects an all-zero row appended for that purpose.
    """
    zero_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    padded = torch.cat((tensor, zero_row), dim=0)
    # Broadcast the per-row index across the two trailing dims, as gather requires.
    gather_index = indices.repeat(1, padded.shape[1], padded.shape[2])
    return torch.gather(padded, 0, gather_index)


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Packs the given 2 tensors using the `indices`.

    Applies `pack_tensor` with the same `indices` to `t1` and then `t2`.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Packs the given 3 tensors using the `indices`.

    Applies `pack_tensor` with the same `indices` to `t1`, `t2` and `t3`, in order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2), pack_tensor(indices, t3)


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Scatters the rows of `tensor` into a zero-initialised tensor with `dim0` rows,
    at the positions given by `indices`. Rows not targeted by any index stay zero.
    Rows whose index equals `dim0` are written to a scratch row that is dropped
    before returning.
    """
    scatter_index = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    buffer = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    buffer.scatter_(0, scatter_index, tensor)
    # Drop the scratch row that absorbed the padding writes.
    return buffer[:-1]


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`.

    Applies `unpack_tensor` with the same `indices` and `dim0` to `t1` and then `t2`.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`.

    Applies `unpack_tensor` with the same `indices` and `dim0` to `t1`, `t2`
    and `t3`, in order.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    The forward pass packs 1 to 3 tensors along dim 0 with a shared `indices`
    tensor. The backward pass scatters the incoming gradients back to the
    original (unpacked) leading size.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        count = len(tensors)
        assert 1 <= count <= 3, f"Packing {count} tensors not supported."
        ctx.save_for_backward(indices)
        # Remember the unpacked leading size so backward can restore it.
        ctx.dim0 = tensors[0].shape[0]
        if count == 1:
            return pack_tensor(indices, tensors[0])
        if count == 2:
            return pack_2_tensors(indices, tensors[0], tensors[1])
        return pack_3_tensors(indices, tensors[0], tensors[1], tensors[2])

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        (indices,) = ctx.saved_tensors
        dim0 = ctx.dim0
        # `indices` is not differentiable, hence the leading None.
        if len(grad_outputs) == 1:
            return None, unpack_tensor(indices, dim0, grad_outputs[0])
        if len(grad_outputs) == 2:
            grads = unpack_2_tensors(indices, dim0, *grad_outputs)
        else:
            grads = unpack_3_tensors(indices, dim0, *grad_outputs)
        return (None, *grads)
440
441
442
443
444
445
446
447
448
449
450
451
452


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    The forward pass scatters a packed tensor back to `dim0` rows with
    `unpack_tensor`. The backward pass packs the gradient again with the same
    `indices`.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        ctx.save_for_backward(indices)
        unpacked = unpack_tensor(indices, dim0, tensor)
        return unpacked

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        packed_grad = pack_tensor(indices, grad_output)
        # No gradients for `indices` or `dim0`.
        return None, None, packed_grad
460
461


462
463
464
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                               recv_tensor, recv_src,
                               cp_group, batch_p2p_comm):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Sends `send_tensor` to `send_dst` and receives into `recv_tensor` from
    `recv_src` within `cp_group`. Even ranks issue the send before the receive,
    and odd ranks the receive before the send (presumably so that neighbouring
    ranks' operations pair up).

    If `batch_p2p_comm` is truthy, both operations are submitted together through
    `torch.distributed.batch_isend_irecv`. Otherwise they are issued one by one
    with `isend` / `irecv`. Either way, the returned list holds pending requests
    that the caller waits on.
    """
    dist = torch.distributed
    send = (dist.isend, (send_tensor, send_dst, cp_group))
    recv = (dist.irecv, (recv_tensor, recv_src, cp_group))
    ordered = [send, recv] if rank % 2 == 0 else [recv, send]

    if batch_p2p_comm:
        p2p_ops = [dist.P2POp(op, *args) for op, args in ordered]
        return dist.batch_isend_irecv(p2p_ops)
    return [op(*args) for op, args in ordered]


508
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim,
                                  softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism

    Weights `out_per_step` by exp(softmax_lse_per_step - softmax_lse) and
    accumulates the result into `out` in place.
    """
    # Move the LSE's dim 2 to `seq_dim`, then add a trailing singleton axis so the
    # weight broadcasts over the last dimension of `out_per_step`.
    weight = torch.exp(softmax_lse_per_step - softmax_lse)
    weight = weight.movedim(2, seq_dim).unsqueeze(-1)
    out.add_(out_per_step * weight)


518
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism

    Overwrites `softmax_lse` in place with
    log(exp(softmax_lse) + exp(softmax_lse_per_step)). The value is computed as
    max + log(1 + exp(min - max)), so the exponent is never positive.
    """
    larger = torch.max(softmax_lse, softmax_lse_per_step)
    smaller = torch.min(softmax_lse, softmax_lse_per_step)
    softmax_lse.copy_(larger + torch.log(1 + torch.exp(smaller - larger)))
525
526


527
class AttnFuncWithCP(torch.autograd.Function):
528
    """
529
530
    Attention implementation with context parallelism.
    Split attention compute into multiple steps, and overlap current-step
531
532
533
534
    compute with next-step communication.
    """

    @staticmethod
535
    def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
536
537
538
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p,
                cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format, attn_mask_type,
                attn_bias_type, attn_bias, deterministic, use_fused_attention):
539
540
541
542
543
544
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
545
        recv_src = cp_global_ranks[(rank - 1) % cp_size]
546
547
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

548
549
        causal = ("causal" in attn_mask_type)
        padding = ("padding" in attn_mask_type)
550

551
552
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

553
        if causal:
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
                q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
                q, k, v = [x.view(2, x.shape[0]//2, *x.shape[1:]) for x in [q, k, v]]
        if attn_bias is not None:
            assert (len(attn_bias.shape) == 4), (
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_bias_ = attn_bias.view( \
                *attn_bias.shape[:-2], \
                2, attn_bias.shape[-2]//2, \
                2*cp_size, attn_bias.shape[-1]//(2*cp_size) \
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
            attn_bias = attn_bias.view( \
                *attn_bias.shape[:-1], \
                2*cp_size, attn_bias.shape[-1]//(2*cp_size) \
            )
576
        assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
577
578
579
580
581
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
582

583
584
585
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
586
        attn_bias_inputs = [None, None]
587
588
589
590
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
591
        attn_biases = [None for _ in range(cp_size)]
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
        p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
        send_recv_reqs = [[], []]

        for i in range(cp_size+1):
            if i < cp_size:
                with torch.cuda.stream(flash_attn_streams[i%2]):
                    # wait until KV is received
                    for req in send_recv_reqs[(i+1)%2]:
                        req.wait()

                    if i < (cp_size-1):
                        p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
                                                                         p2p_comm_buffers[i],
                                                                         send_dst,
                                                                         p2p_comm_buffers[i+1],
                                                                         recv_src,
                                                                         cp_group,
                                                                         batch_p2p_comm)

                    kv_inputs[i%2] = p2p_comm_buffers[i]
                    if causal:
                        if i == 0:
622
                            if use_fused_attention:
623
624
625
626
627
628
629
630
631
632
633
634
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                    q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2].view(
                                        2, k.shape[0], -1, *k.shape[-2:])
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                    q_inputs[i%2] = q.view(-1, *q.shape[-3:])
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2].view(
                                        2, -1, *k.shape[-3:])
635
636
                                elif qkv_format == "thd":
                                    q_inputs[i%2] = q
637
638
639
640
641
642
643
644
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
                                    attn_bias_inputs[i%2] = torch.cat(
                                        (attn_bias[..., idx, :], \
                                         attn_bias[..., (2*cp_size-idx-1), :]),
                                        dim=-1
                                    ).contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
645
646
647
648
649
650
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
651
                                    qkv_layout=qkv_layout, attn_mask_type=attn_mask_type,
652
                                    attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
653
654
                                    seq_offsets_q=seq_offsets_q, seq_offsets_k=seq_offsets_k,
                                    seq_offsets_v=seq_offsets_v, seq_offsets_o=seq_offsets_o,
655
                                )
656
657
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
658
659
660
661
662
663
664
665
666
667
668
669
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                                q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=True, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
670
                        elif i <= rank:
671
                            if use_fused_attention:
672
673
674
675
676
677
678
679
680
681
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                    q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                    q_inputs[i%2] = q.view(-1, *q.shape[-3:])
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2][:, 0, ...].contiguous()
682
683
684
685
686
                                elif qkv_format == "thd":
                                    q_inputs[i%2] = q
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
                                    kv_inputs[i%2] = tex.thd_read_half_tensor(
                                        kv_inputs[i%2], cu_seqlens_k, 0)
687
688
689
690
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
                                    attn_bias_inputs[i%2] = attn_bias[..., idx, :].contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
691
692
693
694
695
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q,
                                    cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
696
697
698
699
700
701
702
703
704
705
706
707
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i%2],
                                    seq_offsets_q=seq_offsets_q,
                                    seq_offsets_k=None if seq_offsets_k is None \
                                        else seq_offsets_k//2,
                                    seq_offsets_v=None if seq_offsets_v is None \
                                        else seq_offsets_v//2,
                                    seq_offsets_o=seq_offsets_o,
708
                                )
709
710
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
711
712
713
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                                q_inputs[i%2] = q.view(-1, *q.shape[-2:])
714
715
716
717
718
719
720
                                if qkv_format == "thd":
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
                                    kv_inputs[i%2] = tex.thd_read_half_tensor(
                                        kv_inputs[i%2], cu_seqlens_k, 0)
                                else:
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
721
722
723
724
725
726
727
728
729
730
731
732
733
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                        else:
                            if use_fused_attention:
734
735
736
737
738
739
740
741
742
743
744
745
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                    q_inputs[i%2] = q[:, 1, ...].contiguous()
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2].view(
                                        2, k.shape[0], -1, *k.shape[-2:])
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                                    q_inputs[i%2] = q[1].contiguous()
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2].view(
                                        2, -1, *k.shape[-3:])
746
747
748
                                elif qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
                                    q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
749
750
751
752
753
754
755
756
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
                                    attn_bias_inputs[i%2] = torch.cat(
                                        (attn_bias_[..., 1, :, idx, :], \
                                         attn_bias_[..., 1, :, (2*cp_size-idx-1), :]),
                                        dim=-1
                                    ).contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
757
758
759
760
761
                                fused_attn_fwd(
                                    is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
762
763
764
765
766
767
768
769
770
771
772
773
                                    attn_scale=softmax_scale,
                                    dropout=dropout_p,
                                    qkv_layout=qkv_layout,
                                    attn_mask_type="padding" if padding else "no_mask",
                                    attn_bias_type=attn_bias_type,
                                    attn_bias=attn_bias_inputs[i%2],
                                    seq_offsets_q=None if seq_offsets_q is None \
                                        else seq_offsets_q//2,
                                    seq_offsets_k=seq_offsets_k,
                                    seq_offsets_v=seq_offsets_v,
                                    seq_offsets_o=None if seq_offsets_o is None \
                                        else seq_offsets_o//2,
774
                                )
775
776
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
777
                            else:
778
779
780
781
782
783
784
                                if qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
                                    q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                                else:
                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
                                    q_inputs[i%2] = \
                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
785
786
787
788
789
790
791
792
793
794
795
796
797
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                    else:
                        if use_fused_attention:
798
799
800
801
802
803
804
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
                                attn_bias_inputs[i%2] = torch.cat(
                                    (attn_bias[..., idx, :], attn_bias[..., (2*cp_size-idx-1), :]),
                                    dim=-1
                                ).contiguous()
                            out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
805
806
807
808
809
810
                            fused_attn_fwd(
                                is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                cu_seqlens_k, q, kv_inputs[i%2][0],
                                kv_inputs[i%2][1], TE_DType[q.dtype],
                                tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                attn_scale=softmax_scale, dropout=dropout_p,
811
                                qkv_layout=qkv_layout, attn_mask_type=attn_mask_type,
812
                                attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
813
814
                                seq_offsets_q=seq_offsets_q, seq_offsets_k=seq_offsets_k,
                                seq_offsets_v=seq_offsets_v, seq_offsets_o=seq_offsets_o,
815
                            )
816
817
                            if len(rest) > 0:
                                attn_biases[i] = rest[0]
818
                        else:
819
820
821
                            # [b, sq, np, hn] -> [b*sq, np, hn]
                            q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
822
                            kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
823
824
825
                            _, _, _, _, out_per_step[i], \
                            softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
826
827
828
                                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal=False, return_softmax=False,
                                **fa_optional_forward_kwargs
829
                            )
830
831
832
833
834
835

            if i > 0:
                # wait until fwd restuls correction of last step is done
                if i > 1:
                    flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)

836
837
838
839
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
                    softmax_lse_per_step[i-1].squeeze_(-1)

840
                with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
841
842
843
                    if i == 1:
                        out = torch.empty_like(q).zero_()
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
844
                        if causal and qkv_format != "thd":
845
846
847
848
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
                            )
849
850
851
                    elif (i-1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(softmax_lse,
                                                              softmax_lse_per_step[i-1])
852
                    else:
853
854
855
856
857
858
859
860
                        if qkv_format == "thd":
                            tex.thd_second_half_lse_correction(softmax_lse,
                                                               softmax_lse_per_step[i-1],
                                                               cu_seqlens_q,
                                                               q.size(0))
                        else:
                            flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
                                                                  softmax_lse_per_step[i-1])
861
862
863
864
865
866
867

                if i < cp_size:
                    flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
868
869
        if qkv_format in ["bshd", "sbhd"]:
            seq_dim = qkv_format.index("s")
870
        for i in range(cp_size):
871
872
873
874
875
876
            if qkv_format == "bshd":
                out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
                out_ = out[:, 1, ...]
            elif qkv_format == "sbhd":
                out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
                out_ = out[1]
877

878
            if i <= rank or not causal:
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
                if qkv_format in ["bshd", "sbhd"]:
                    flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape),
                                                  out_per_step[i],
                                                  seq_dim,
                                                  softmax_lse,
                                                  softmax_lse_per_step[i])
                elif qkv_format == "thd":
                    tex.thd_out_correction(out,
                                           out_per_step[i],
                                           softmax_lse,
                                           softmax_lse_per_step[i],
                                           cu_seqlens_q,
                                           False)
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
894
            else:
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
                if qkv_format in ["bshd", "sbhd"]:
                    flash_attn_fwd_out_correction(out_,
                                                  out_per_step[i],
                                                  seq_dim,
                                                  softmax_lse_[..., 1, :],
                                                  softmax_lse_per_step[i])
                elif qkv_format == "thd":
                    tex.thd_out_correction(out,
                                           out_per_step[i],
                                           softmax_lse,
                                           softmax_lse_per_step[i],
                                           cu_seqlens_q,
                                           True)
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
910
911

        kv = p2p_comm_buffers[-1]
912
        if use_fused_attention:
913
914
915
916
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
917
918
        else:
            out = out.view(-1, *out.shape[-2:])
919

920
921
922
923
924
        ctx.save_for_backward(
            q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
            *rng_states, *attn_biases
        )
925
926
927
928
929
930
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
931
        ctx.qkv_format = qkv_format
932
        ctx.attn_mask_type = attn_mask_type
933
934
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
935
        ctx.deterministic = deterministic
936
        ctx.use_fused_attention = use_fused_attention
937
938
939
940
        return out

    @staticmethod
    def backward(ctx, dout):
941
        (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
942
        (seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o) = ctx.saved_tensors[6:10]
943
        cp_size = get_distributed_world_size(ctx.cp_group)
944
945
        rng_states = ctx.saved_tensors[10:10+cp_size]
        attn_biases = ctx.saved_tensors[10+cp_size:10+cp_size*2]
946

947
        rank = get_distributed_rank(ctx.cp_group)
948
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
949
950
951
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

952
953
        causal = ("causal" in ctx.attn_mask_type)
        padding = ("padding" in ctx.attn_mask_type)
954
955
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

956
        if attn_biases[0] is not None:
957
958
959
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
                *ctx.attn_bias_shape,
960
961
                dtype=attn_biases[0].dtype,
                device=attn_biases[0].device
962
963
964
965
966
967
968
969
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3]//2, *attn_dbias.shape[-2:]
            )
        else:
            attn_dbias = None

970
        if causal:
971
972
973
974
975
976
977
978
979
980
981
            if ctx.qkv_format == "thd":
                softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
                softmax_lse_ = \
                    softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
                if ctx.use_fused_attention:
                    # [b, np, sq//2] -> [b, np, sq//2, 1]
                    softmax_lse_.unsqueeze_(-1)

982
983
984
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse.unsqueeze_(-1)
985
986
987
988
989
990
991
992
993
994
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        # Flash Attn outputs
        dq = torch.empty_like(q)

        p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
                            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
        p2p_comm_buffers[0][0].copy_(kv)
        send_recv_reqs = []

995
996
997
998
999
1000
        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

            send_tensor = p2p_comm_buffers[i%2]
            recv_tensor = p2p_comm_buffers[(i+1)%2]
            if i == 0:
                send_tensor = send_tensor[0]
                recv_tensor = recv_tensor[0]
            if i == (cp_size-1):
                send_tensor = send_tensor[1]
                recv_tensor = recv_tensor[1]

            send_recv_reqs = flash_attn_p2p_communicate(rank,
                                                        send_tensor,
                                                        send_dst,
                                                        recv_tensor,
                                                        recv_src,
                                                        ctx.cp_group,
                                                        batch_p2p_comm)

            kv = p2p_comm_buffers[i%2][0]
            # In reversed order of fwd
1025
            if causal:
1026
                if i == (cp_size-1):
1027
                    if ctx.use_fused_attention:
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                            kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
                            # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
1044
1045
                        elif ctx.qkv_format == "thd":
                            q_, kv_, out_, dout_ = q, kv, out, dout
1046
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]]
1047
                        if attn_dbias is not None:
1048
                            aux_ctx_tensors += [attn_biases[cp_size-i-1]]
1049
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1050
1051
                            ctx.max_seqlen_q, ctx.max_seqlen_k,
                            cu_seqlens_q, cu_seqlens_k,
1052
1053
                            q_, kv_[0], kv_[1], out_, dout_,
                            TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
1054
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1055
                            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
1056
1057
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1058
                            qkv_layout=qkv_layout,
1059
                            attn_mask_type=ctx.attn_mask_type,
1060
                            attn_bias_type=ctx.attn_bias_type,
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, 0]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                            ctx.max_seqlen_q, ctx.max_seqlen_k,
                            ctx.dropout_p, ctx.softmax_scale, True,
1079
                            rng_state=rng_states[cp_size-i-1],
1080
1081
1082
1083
                            **fa_optional_backward_kwargs
                        )
                elif i >= (cp_size-rank-1):
                    if ctx.use_fused_attention:
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous()
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
                            # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
                            kv_ = kv[:, 0, ...].contiguous()
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
1100
1101
1102
1103
                        elif ctx.qkv_format == "thd":
                            q_, out_, dout_ = q, out, dout
                            # [2, t, np, hn] -> [2, t/2, np, hn]
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
1104
                        aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]]
1105
                        if attn_dbias is not None:
1106
                            aux_ctx_tensors += [attn_biases[cp_size-i-1]]
1107
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1108
1109
                            ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                            cu_seqlens_q, cu_seqlens_k//2,
1110
1111
                            q_, kv_[0], kv_[1], out_, dout_,
                            TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
1112
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1113
1114
                            seq_offsets_q, None if seq_offsets_k is None else seq_offsets_k//2,
                            None if seq_offsets_v is None else seq_offsets_v//2, seq_offsets_o,
1115
1116
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1117
                            qkv_layout=qkv_layout,
1118
                            attn_mask_type="padding" if padding else "no_mask",
1119
                            attn_bias_type=ctx.attn_bias_type,
1120
1121
1122
1123
1124
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
1125
1126
1127
1128
1129
1130
                        if ctx.qkv_format == "thd":
                            # [2, t, np, hn] -> [2, t/2, np, hn]
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
                        else:
                            # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
                            ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                            ctx.dropout_p, ctx.softmax_scale, False,
1142
                            rng_state=rng_states[cp_size-i-1],
1143
1144
1145
1146
                            **fa_optional_backward_kwargs
                        )
                else:
                    if ctx.use_fused_attention:
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous()
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                            kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous()
                            dout_ = dout[:, 1, ...].contiguous()
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            q_ = q[1].contiguous()
                            # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
1163
1164
1165
1166
1167
1168
                        elif ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
                            kv_ = kv
1169
                        aux_ctx_tensors = [softmax_lse_, rng_states[cp_size-i-1]]
1170
                        if attn_dbias is not None:
1171
                            aux_ctx_tensors += [attn_biases[cp_size-i-1]]
1172
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1173
1174
                            ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                            cu_seqlens_q//2, cu_seqlens_k,
1175
1176
                            q_, kv_[0], kv_[1], out_, dout_,
                            TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
1177
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1178
1179
                            None if seq_offsets_q is None else seq_offsets_q//2, seq_offsets_k,
                            seq_offsets_v, None if seq_offsets_o is None else seq_offsets_o//2,
1180
1181
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1182
                            qkv_layout=qkv_layout,
1183
                            attn_mask_type="padding" if padding else "no_mask",
1184
                            attn_bias_type=ctx.attn_bias_type,
1185
1186
                        )
                    else:
1187
1188
1189
1190
1191
1192
                        if ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
1193
1194
1195
1196
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
1197
1198
1199
1200
1201
1202
1203
                        if ctx.qkv_format == "thd":
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                            dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
1204
1205
1206
1207
1208
1209
1210
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
                            ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                            ctx.dropout_p, ctx.softmax_scale, False,
1211
                            rng_state=rng_states[cp_size-i-1],
1212
1213
1214
1215
                            **fa_optional_backward_kwargs
                        )
            else:
                if ctx.use_fused_attention:
1216
                    aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]]
1217
                    if attn_dbias is not None:
1218
                        aux_ctx_tensors += [attn_biases[cp_size-i-1]]
1219
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1220
1221
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        cu_seqlens_q, cu_seqlens_k,
1222
1223
                        q, kv[0], kv[1], out, dout,
                        TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
1224
                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
1225
                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
1226
1227
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
1228
                        qkv_layout=qkv_layout,
1229
                        attn_mask_type=ctx.attn_mask_type,
1230
                        attn_bias_type=ctx.attn_bias_type,
1231
1232
1233
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
1234
1235
                    q_ = q.view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
1236
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
1237
1238
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
1239
                    # [b, sq, np, hn] -> [b*sq, np, hn]
1240
1241
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
1242
1243
                    if _flash_attn_2_3_plus:
                        fa_optional_backward_kwargs["window_size"] = [-1, -1]
1244
1245
1246
1247
1248
                    _flash_attn_backward(
                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                        dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        ctx.dropout_p, ctx.softmax_scale, False,
1249
                        **fa_optional_backward_kwargs
1250
1251
                    )

1252
            if i >= (cp_size-rank-1) or not causal:
1253
1254
1255
1256
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
1257
1258
1259
1260
1261
1262
                if ctx.qkv_format == "bshd":
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
                elif ctx.qkv_format == "sbhd":
                    # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
                    dq_ = dq_.view(-1, *dq.shape[-3:])
1263

1264
            if causal:
1265
1266
1267
1268
1269
1270
                if i > (cp_size-rank-1):
                    dq.add_(dq_)
                elif i == (cp_size-rank-1):
                    if rank == (cp_size-1):
                        dq.copy_(dq_)
                    else:
1271
1272
1273
1274
1275
1276
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
1277
1278
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add")
1279
                elif i > 0:
1280
1281
1282
1283
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
1284
1285
                    elif ctx.qkv_format == "thd":
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add")
1286
                else:
1287
1288
1289
1290
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
1291
1292
                    elif ctx.qkv_format == "thd":
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy")
1293
1294
1295
1296
1297
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
1298

1299
1300
            if attn_dbias is not None:
                idx = (rank+i+1)%cp_size
1301
                if i == (cp_size - 1) or not causal:
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2)
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
                    attn_dbias[..., (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size-rank-1):
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2)
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
                    attn_dbias_[..., 1, :, (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :])

1315
1316
1317
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
1318

1319
1320
1321
            dkv = p2p_comm_buffers[(i+1)%2][1]
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
1322
            if causal and i >= (cp_size-rank-1) and i != (cp_size-1):
1323
1324
1325
1326
1327
1328
                if ctx.qkv_format == "bshd":
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                elif ctx.qkv_format == "sbhd":
                    # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
1329
1330
1331
1332
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)
1333

1334
            if causal:
1335
1336
                if i == (cp_size-1):
                    if rank == 0:
1337
1338
1339
1340
1341
1342
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
1343
1344
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy")
1345
1346
1347
1348
                    else:
                        dkv.add_(dkv_)
                elif i >= (cp_size-rank-1):
                    if i == 0 and rank == (cp_size-1):
1349
1350
1351
1352
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
1353
1354
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none")
1355
                    else:
1356
1357
1358
1359
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
1360
1361
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none")
1362
1363
1364
1365
1366
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
1367
1368
1369
1370
1371
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

1372
        if causal:
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                dq = dq.view(q.shape[0], -1, *q.shape[-2:])
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                dq = dq.view(-1, *q.shape[-3:])
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:])

        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

1388
1389
        return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, None, None, \
                None, None, None, None, None, None, None, None, attn_dbias, None, None
1390
1391
1392


def attn_forward_func_with_cp(
    is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
    seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p,
    cp_group, cp_global_ranks, cp_stream, softmax_scale=None, qkv_format="bshd",
    attn_mask_type="causal", attn_bias_type="no_bias", attn_bias=None, deterministic=False,
    use_fused_attention=False
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.

    Validates that the requested combination of QKV format, mask type, attention
    bias and backend is supported with context parallelism. It then forwards every
    argument unchanged, in the same order, to ``AttnFuncWithCP.apply``.

    Each of the following violations raises ``AssertionError``:

    * ``qkv_format`` is not one of ``"bshd"``, ``"sbhd"`` or ``"thd"``;
    * ``"sbhd"`` is requested without FusedAttention (FlashAttention cannot take it);
    * ``"thd"`` with FusedAttention uses a mask other than ``"padding"`` or
      ``"padding_causal"``;
    * ``attn_bias`` is given without FusedAttention, or with a mask type that
      contains ``"padding"``.

    Returns
    -------
    The result of ``AttnFuncWithCP.apply``.
    """
    assert qkv_format in ["bshd", "sbhd", "thd"], (
        f"QKV format of {qkv_format} is not supported with context parallelism!"
    )
    assert use_fused_attention or qkv_format != "sbhd", (
        "FlashAttention does not support sbhd format!"
    )

    backend = "FusedAttention" if use_fused_attention else "FlashAttention"
    thd_mask_supported = (
        qkv_format != "thd"
        or not use_fused_attention
        or attn_mask_type in ["padding", "padding_causal"]
    )
    assert thd_mask_supported, (
        f"Context parallelism is not supported for {attn_mask_type} mask type and "
        f"{qkv_format} format with {backend}!"
    )

    bias_supported = attn_bias is None or (
        use_fused_attention and "padding" not in attn_mask_type
    )
    assert bias_supported, (
        'Attention bias is only supported with FusedAttention and "causal" '
        'or "no_mask" mask types!'
    )

    return AttnFuncWithCP.apply(
        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, dropout_p,
        cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format, attn_mask_type,
        attn_bias_type, attn_bias, deterministic, use_fused_attention,
    )


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """
    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
            Values below 1.0 shrink the dimension to ``int(dim * rotary_percent)``.
        seq_len_interpolation_factor: int, default = None
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595. Only takes effect when
            `pretrained_max_position_embeddings` is also set.
        pretrained_max_position_embeddings: int, default = None
            pre-trained max_position_embeddings before position interpolation
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor

        # inv_freq[j] = 1 / 10000^(2j / dim) for j = 0 .. ceil(dim / 2) - 1,
        # allocated on the current CUDA device.
        exponents = torch.arange(
            0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()
        ) / dim
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Tensor of shape ``[max_seq_len, 1, 1, 2 * len(inv_freq)]``. Its last
            dimension holds the per-position angles twice (two identical halves).
        """
        positions = torch.arange(
            max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        ) + offset

        factor = self.seq_len_interpolation_factor
        trained_len = self.pretrained_max_position_embeddings
        if trained_len is not None and factor is not None:
            if max_seq_len > trained_len * factor:
                # dynamic linear scaling (length > position we have learned)
                positions *= 1 / (max_seq_len / trained_len)
            else:
                # fixed linear scaling by the interpolation factor
                positions *= 1 / factor

        # Outer product: angles[i, j] = positions[i] * inv_freq[j]
        angles = torch.einsum('i,j->ij', positions, self.inv_freq)
        # Duplicate along the last dim so both halves used by the rotation share angles.
        emb = torch.cat([angles, angles], dim=-1)
        return emb.view(emb.shape[0], 1, 1, emb.shape[1])


class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """
        Apply RoPE to ``t`` with the fused ``tex`` kernels.

        ``freqs`` is cast to float32 if it is not already float32. ``bshd`` input is
        passed to the kernel transposed on its first two dimensions, and the result is
        transposed back. ``thd`` input is passed together with ``cu_seqlens``.

        Raises
        ------
        ValueError
            If ``tensor_format`` is not ``sbhd``, ``bshd`` or ``thd``.
        """
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(
                t.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        # Save the (float32) freqs actually used so backward applies the same rotation.
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """
        Propagate ``grad_output`` through the fused RoPE backward kernels.

        Returns exactly one entry per ``forward`` input: ``t``, ``freqs``,
        ``tensor_format`` and ``cu_seqlens``. Only ``t`` receives a gradient.

        Raises
        ------
        ValueError
            If the saved ``tensor_format`` is not ``sbhd``, ``bshd`` or ``thd``.
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # forward() takes four inputs (t, freqs, tensor_format, cu_seqlens), so exactly
        # four gradients are returned, as the autograd.Function contract requires.
        return grad_input, None, None, None


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


1555
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use the fused RoPE implementation (`FusedRoPEFunc`) instead of
        the PyTorch one.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Notes
    -----
    In the unfused path only the first `d2` channels of `t` are rotated. The remaining
    `d - d2` channels are passed through unchanged and appended after them.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    seq_dim = 1 if tensor_format == "bshd" else 0
    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[seq_dim]

    # Only the first `cur_seq_len` positions of the table are needed for this input.
    assert cur_seq_len <= max_seq_len, (
        f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    )
    freqs = freqs[:cur_seq_len]
    if seq_dim == 1:
        # [seq, 1, 1, dim] -> [1, seq, 1, dim] so it lines up with [b, s, h, d]
        freqs = freqs.transpose(0, 1)

    rot_dim = freqs.shape[-1]
    # Ideally t_pass is empty, i.e. the rotation covers every channel of t.
    t_rot = t[..., :rot_dim]
    t_pass = t[..., rot_dim:]

    # Take cos/sin in the dtype of freqs and only then cast, for better precision.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    # Cosine term on t itself, sine term on its sign-changed half-rotation.
    t_rot = t_rot * cos_ + _rotate_half(t_rot) * sin_
    return torch.cat((t_rot, t_pass), dim=-1)


def _noop_cat_along_dim(
    tensors: List[torch.Tensor],
    split_sizes: List[int],
    split_dim: int,
) -> Optional[torch.Tensor]:
    """Concatenate ``tensors`` along ``split_dim`` without copying, if possible.

    This succeeds only when the tensors are consecutive views into one storage that
    together form a single strided tensor. That requires identical strides, the
    expected shapes, and the same storage. Each tensor must also start exactly where
    the previous one ends along ``split_dim``, i.e. at
    ``base_offset + sum(previous sizes) * stride[split_dim]``.

    Returns that combined view, or ``None`` when the caller must fall back to
    ``torch.cat``.
    """
    first = tensors[0]
    strides = first.stride()
    storage_ptr = first.untyped_storage().data_ptr()
    base_offset = first.storage_offset()
    expected_shape = list(first.shape)  # private copy, mutated per piece below

    offset = base_offset
    for size, tensor in zip(split_sizes, tensors):
        expected_shape[split_dim] = size
        if (tensor.stride() != strides
                or list(tensor.shape) != expected_shape
                or tensor.untyped_storage().data_ptr() != storage_ptr
                or tensor.storage_offset() != offset):
            return None
        # The next piece must begin right after this one along split_dim. Using the
        # real stride (not the contiguous one) keeps the combined view exact for any
        # layout.
        offset += size * strides[split_dim]

    new_shape = list(first.shape)
    new_shape[split_dim] = sum(split_sizes)
    ret = torch.Tensor().to(device=first.device, dtype=first.dtype)
    ret.set_(first.untyped_storage(), base_offset, new_shape, strides)
    return ret


class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along one dimension, with a copy-free backward when possible.

    Forward behaves like ``torch.split``. A ``Float8Tensor`` is split through its raw
    ``_data``, and every piece is re-wrapped with ``Float8Tensor.make_like``. Backward
    concatenates the incoming gradients along the same dimension. When they already
    are adjacent views of a single tensor, it reuses their storage instead of copying.
    """

    @staticmethod
    def forward(ctx,
                mixed_x_layer: torch.Tensor,
                split_dim: int,
                split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        """Split ``mixed_x_layer`` into pieces along ``split_dim`` (see ``torch.split``)."""
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        if isinstance(mixed_x_layer, Float8Tensor):
            return tuple(Float8Tensor.make_like(
                mixed_x_layer,
                data=x,
                ) for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim))
        return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)

    @staticmethod
    def backward(ctx,
                 *grad_outputs):
        """Concatenate the piece gradients into a gradient for the unsplit input."""
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert (len(grad_outputs) == len(split_sizes)
                ), "Unequal number of gradients vs split sections for backprop!"
        else:
            # Integer split: torch.split gives every piece this size except possibly a
            # shorter last one. A short last piece fails the shape check in the no-op
            # path, so torch.cat is used instead.
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

        if isinstance(grad_outputs[0], Float8Tensor):
            # Work on the raw FP8 payloads, then re-wrap the result.
            grad_outputs_data = [x._data for x in grad_outputs]
            data = _noop_cat_along_dim(grad_outputs_data, split_sizes, split_dim)
            if data is None:
                data = torch.cat(grad_outputs_data, dim = split_dim)
            return Float8Tensor.make_like(grad_outputs[0], data=data), None, None

        ret = _noop_cat_along_dim(grad_outputs, split_sizes, split_dim)
        if ret is None:
            ret = torch.cat(grad_outputs, dim = split_dim)
        return ret, None, None


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        # Enabled by NVTE_APPLY_QK_LAYER_SCALING=1, and only when layer_number is given;
        # forward() additionally requires FP16 inputs.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop.

        Computes attention with explicit batched GEMMs. It first forms the scaled
        (and optionally biased) Q K^T scores, then applies a masked softmax followed by
        dropout, and finally multiplies by V.

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            Inputs in the format encoded by `qkv_layout`. `bshd` inputs are transposed to
            `[s, b, np, hn]` and processed like `sbhd`. If `key_layer`/`value_layer` have
            fewer heads than `query_layer` (GQA), they are repeated along the head
            dimension; the query head count must be divisible by the key head count.
        qkv_layout: str, default = "sbh3d"
            Must be one of `QKVLayouts`.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            Unused; accepted for interface compatibility.
        attn_mask_type: str, default = "causal"
            Mask type passed, together with `attention_mask`, to the softmax.
        attention_mask: optional
            Mask passed to the softmax.
        core_attention_bias_type: {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
            "pre_scale_bias" adds `core_attention_bias` before scaling and
            "post_scale_bias" adds it after scaling. "alibi" builds the bias with
            `get_alibi` from `alibi_slopes`.
        core_attention_bias: Optional[torch.Tensor]
            Bias broadcastable to `[b, np, sq, sk]`; required for pre/post scale bias.
        alibi_slopes: Optional[torch.Tensor]
            Passed to `get_alibi` when `core_attention_bias_type` is "alibi".

        Returns
        -------
        torch.Tensor
            `[sq, b, np * hn]` for `sbhd` and `[b, sq, np * hn]` for `bshd`. For any
            other format, the `[b, np, sq, hn]` result is returned without reshaping.

        Raises
        ------
        ValueError
            If `core_attention_bias_type` is not one of the supported values.
        """
        assert (qkv_layout in QKVLayouts
            ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        supported_bias_types = ("no_bias", "pre_scale_bias", "post_scale_bias", "alibi")
        if core_attention_bias_type not in supported_bias_types:
            # Without this check, an unknown type would skip every score branch below.
            # The uninitialized `torch.empty` buffer would then be used as the scores.
            raise ValueError(
                "UnfusedDotProductAttention does not support "
                f"core_attention_bias_type = {core_attention_bias_type}!")

        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
        if qkv_format == 'bshd':
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [x.transpose(0, 1)
                for x in [query_layer, key_layer, value_layer]]

        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            # GQA: repeat each K/V head so that every query head has a K/V partner.
            assert (query_layer.shape[2]%key_layer.shape[2]==0
                ),"The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                    query_layer.shape[2] // key_layer.shape[2], dim = 2)
            value_layer = value_layer.repeat_interleave(
                    query_layer.shape[2] // value_layer.shape[2], dim = 2)

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(
            output_size[2], output_size[0] * output_size[1], -1
        )
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.softmax_scale
        if apply_qk_layer_scaling:
            # Scores use a reduced scale here; layer_number is handed to the softmax
            # as its scale (see softmax_scale below).
            scale /= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3])
            matmul_result *= scale
            # Cast back, as the post-scale path does. Otherwise a bias of a wider dtype
            # (e.g. FP32 with FP16/BF16 inputs) promotes the scores, and they no longer
            # match the dtype of `value_layer` in the final bmm.
            matmul_result = matmul_result.to(dtype=query_layer.dtype)

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes)
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=scale,
            )
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3]).to(
                dtype=query_layer.dtype)

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask, attn_mask_type, softmax_scale)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == 'sbhd':
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        if qkv_format == 'bshd':
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Autograd function that splits an interleaved QKV buffer for flash-attn.

    Forward converts QKV from interleaved (s, b, ...) layout to separate
    contiguous q, k, v tensors in (b, s, ...) layout via ``tex.fa_prepare_fwd``.
    Backward re-packs the three incoming gradients with ``tex.fa_prepare_bwd``
    and splits the result into three pieces along its last dimension.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # q, k and v arrive as non-contiguous views of one QKV allocation;
        # only `query_layer` is handed to the extension, which uses it to reach
        # the full memory region of the QKV tensor.
        packed = tex.fa_prepare_fwd(query_layer)
        pieces = split_tensor_along_dim(packed, 0, 3)
        # Each piece keeps a leading dim of its own; drop it.
        return tuple(torch.squeeze(piece, 0) for piece in pieces)

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        repacked = tex.fa_prepare_bwd(dq, dk, dv)
        grad_q, grad_k, grad_v = split_tensor_along_dim(repacked, -1, 3)
        return grad_q, grad_k, grad_v

1936

1937
1938
1939
1940
1941
1942
1943
def _get_qkv_layout(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qkv_format: str = 'sbhd',
    ) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get qkv layout.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of sequences in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    q: torch.Tensor
       The input query tensor, or a contiguous copy of it when the original memory
       layout was not recognized.
    k: torch.Tensor
       The input key tensor, or a contiguous copy (same rule as `q`).
    v: torch.Tensor
       The input value tensor, or a contiguous copy (same rule as `q`).

    Raises
    ----------
    ValueError
       If no supported layout is found even after making `q`, `k`, `v` contiguous.
    """

    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

    def run_iteratively(q, k, v):
        # Shared underlying storage between q/k/v and between k/v.
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        stride = k.stride()
        check_strides_kv = all(stride == x.stride() for x in [k, v])

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = all(shape == x.shape for x in [k, v])

        # Offsets consistent with interleaving at the `h` dim (i.e. `h3d` / `h2d`).
        last_dim_size = q.shape[-1]
        check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset()
                            for i, x in enumerate([q, k, v]))
        last_dim_size = k.shape[-1]
        check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset()
                            for i, x in enumerate([k, v]))

        # Offsets consistent with interleaving before the `h` dim (i.e. `3hd` / `2hd`).
        last_two_dims_size = q.shape[-1] * q.shape[-2]
        check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset()
                            for i, x in enumerate([q, k, v]))
        last_two_dims_size = k.shape[-1] * k.shape[-2]
        check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
                            for i, x in enumerate([k, v]))

        # In the kv-packed layouts q is its own chunk, and in the separate layout each
        # of q, k, v is. The returned layout string encodes no strides, so only report
        # those layouts for dense tensors; a strided view (e.g. a transposed tensor, or
        # k/v carved out of a larger buffer at a non-zero offset) falls through to
        # 'not_supported', which triggers the contiguous retry below.
        q_dense = q.is_contiguous()
        qkv_dense = q_dense and k.is_contiguous() and v.is_contiguous()

        if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
            and check_last_two_dims_offsets_qkv
            and not check_last_dim_offsets_qkv):
            # sb3hd, bs3hd, t3hd
            qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:]
        elif (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
            and check_last_dim_offsets_qkv):
            # sbh3d, bsh3d, th3d
            qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:]
        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
            and check_last_two_dims_offsets_kv
            and not check_last_dim_offsets_kv
            and q_dense):
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:]
        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
            and check_last_dim_offsets_kv
            and q_dense):
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:]
        elif check_strides_kv and check_shapes_kv and qkv_dense:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            qkv_layout = '_'.join(list([qkv_format])*3)
        else:
            qkv_layout = 'not_supported'

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == 'not_supported':
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == 'not_supported':
        raise ValueError("The provided qkv memory layout is not supported!")

    return qkv_layout, q, k, v
2042

2043

2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
def check_set_window_size(
        attn_mask_type: str,
        window_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[int, int]:
    """Check if sliding window size is compliant with mask type and if not,
    assert or set it to the appropriate size.

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type. Any type containing the substring "causal" is treated
        as causal.
    window_size: Optional[Tuple[int, int]], default = `None`
        Sliding window size as a (left, right) pair. When `None`, it defaults to
        `(-1, 0)` for causal mask types and `(-1, -1)` otherwise.

    Returns
    ----------
    window_size: Tuple[int, int]
        The validated or defaulted window size. It is always returned as a tuple, so
        callers can compare it against tuple literals such as `(-1, 0)` even when a
        list was passed in.

    Raises
    ----------
    AssertionError
        If the mask type is causal and `window_size[1]` is not 0.
    """
    is_causal = "causal" in attn_mask_type
    if window_size is None:
        return (-1, 0) if is_causal else (-1, -1)
    if is_causal:
        assert (
            window_size[1] == 0
        ), "window_size[1] should be 0 when self_attn_mask_type includes 'causal'!"
    return tuple(window_size)
2062

2063

2064
class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Inputs are rearranged into the varlen layout that ``flash_attn_forward_func``
    consumes (or handed to ``attn_forward_func_with_cp`` under context
    parallelism), and the result is reshaped back to the caller's ``qkv_format``.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        # Only flash-attn releases within the supported range are accepted.
        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            _flash_attn_version <= _flash_attn_max_version
        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.softmax_scale = softmax_scale
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = layer_number if layer_number is not None else 1
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
    ) -> torch.Tensor:
        """flash-attn fprop.

        `query_layer`, `key_layer` and `value_layer` must be FP16/BF16 CUDA tensors
        laid out as `qkv_layout`. For padding mask types, `attention_mask` is used
        when the cu_seqlens are not supplied: a single tensor for self attention, a
        (q_mask, kv_mask) pair otherwise. For `thd` input, `cu_seqlens_q` and
        `cu_seqlens_kv` are mandatory.

        Returns the attention output reshaped to [s, b, h*d] (`sbhd`),
        [b, s, h*d] (`bshd`) or [t, h*d] (`thd`).
        """
        window_size = check_set_window_size(attn_mask_type, window_size)

        inputs = (query_layer, key_layer, value_layer)
        half_dtypes = (torch.float16, torch.bfloat16)
        assert all(
            t.dtype in half_dtypes for t in inputs
        ), "FlashAttention currently only supports FP16 and BF16."
        assert all(
            t.is_cuda for t in inputs
        ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        use_cp = cp_group is not None and get_distributed_world_size(cp_group) != 1

        # Letters of the first layout component, e.g. "sbh3d" -> "sbhd".
        qkv_format = "".join(ch for ch in qkv_layout.split("_")[0] if ch.isalpha())

        # Bring the inputs into contiguous b-first (or t-first) order.
        if qkv_format == "sbhd":
            # For now just head dim 128, will make it more general in the future
            use_fa_prepare = (
                query_layer.shape[-1] == 128
                and query_layer.shape[0] * query_layer.shape[1] >= 512
                and qkv_layout == "sbh3d"
            )
            if use_fa_prepare:
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
                    query_layer, key_layer, value_layer
                )
            else:
                query_layer, key_layer, value_layer = [
                    t.transpose(0, 1).contiguous()
                    for t in (query_layer, key_layer, value_layer)
                ]
        elif qkv_format in ("bshd", "thd"):
            query_layer, key_layer, value_layer = [
                t.contiguous() for t in (query_layer, key_layer, value_layer)
            ]

        batch_size = query_layer.shape[0]

        if qkv_format in ("sbhd", "bshd"):
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]

            if not use_cp:
                # [b, s, h, d] -> [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    t.view(t.shape[0] * t.shape[1], *t.shape[2:])
                    for t in (query_layer, key_layer, value_layer)
                ]

            if "padding" in attn_mask_type:
                assert not use_cp, "Padding mask not supported with context parallelism!"

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is not None:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    else:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is not None and cu_seqlens_kv is not None:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    else:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(
                            attention_mask[1]
                        )
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(
                        indices_kv, key_layer, value_layer
                    )
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size, max_seqlen_q, query_layer.device
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size, max_seqlen_kv, key_layer.device
                    )
        elif qkv_format == "thd":
            assert (
                cu_seqlens_q is not None and cu_seqlens_kv is not None
            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            # Derive the longest sequence from consecutive cu_seqlens differences.
            if max_seqlen_q is None:
                max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
            if max_seqlen_kv is None:
                max_seqlen_kv = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).max().item()

        if use_cp:
            assert window_size in (
                (-1, -1),
                (-1, 0),
            ), "Sliding window attention is not supported with context parallelism."
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    None,
                    None,
                    None,
                    None,
                    self.attention_dropout if self.training else 0.0,
                    cp_group,
                    cp_global_ranks,
                    cp_stream,
                    softmax_scale=self.softmax_scale,
                    qkv_format="bshd" if qkv_format == "sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic,
                )
        else:
            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                for t in (query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv):
                    if t is not None:
                        t.activation_offloading = True

            # Pass only the keyword arguments the installed flash-attn understands.
            fa_kwargs = {}
            if _flash_attn_2_3_plus:
                fa_kwargs["window_size"] = window_size
            if _flash_attn_2_4_plus:
                fa_kwargs["alibi_slopes"] = alibi_slopes
            if _flash_attn_2_4_1_plus:
                fa_kwargs["deterministic"] = self.deterministic

            with self.attention_dropout_ctx():
                output = flash_attn_forward_func(
                    query_layer,
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=self.softmax_scale,
                    causal="causal" in attn_mask_type,
                    **fa_kwargs,
                )

        if qkv_format in ("sbhd", "bshd") and "padding" in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            return output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
        if qkv_format == "bshd":
            # (bs)hd -> bs(hd)
            return output.view(batch_size, max_seqlen_q, -1).contiguous()
        if qkv_format == "thd":
            # thd -> t(hd)
            return output.view(output.shape[0], -1).contiguous()
        return output
2272

2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
def _combine_tensors(
        tensors: List[torch.Tensor],
        dim: int,
    ) -> torch.Tensor:
    """Combine tensors along a particular dimension.

    Builds a view (no copy) over the storage of ``tensors[0]``, starting at its storage
    offset, with a new dimension of size ``len(tensors)`` inserted at ``dim``. The new
    dimension's stride is ``tensors[0].stride()[dim - 1] // len(tensors)``. This is only
    meaningful when the inputs are equally spaced, interleaved views of a single buffer
    that begins at ``tensors[0]`` (e.g. q, k, v split from a packed QKV tensor). Nothing
    here verifies that; callers must ensure it.

    For a ``Float8Tensor`` input, the view is built over its ``_data`` storage and
    re-wrapped with ``Float8Tensor.make_like`` so the FP8 metadata carries over.
    """

    num_tensors = len(tensors)
    new_shape = list(tensors[0].shape)
    new_shape.insert(dim, num_tensors)
    new_stride = list(tensors[0].stride())
    # Strides are integers; use exact integer division rather than float division.
    new_stride.insert(dim, new_stride[dim - 1] // num_tensors)
    if isinstance(tensors[0], Float8Tensor):
        combined_tensor = torch.Tensor().to(
            device=tensors[0].device, dtype=tensors[0]._data.dtype)
        combined_tensor.set_(
            tensors[0]._data.untyped_storage(),
            tensors[0]._data.storage_offset(),
            new_shape, new_stride)
        combined_tensor = Float8Tensor.make_like(
            tensors[0], data=combined_tensor)
    else:
        combined_tensor = torch.Tensor().to(
            device=tensors[0].device, dtype=tensors[0].dtype)
        combined_tensor.set_(
            tensors[0].untyped_storage(),
            tensors[0].storage_offset(),
            new_shape, new_stride)

    return combined_tensor
2302

2303
2304
2305
2306
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input.

    ``forward`` runs ``fused_attn_fwd_qkvpacked`` either in FP8 (``fp8=True``; the
    backend is forced to ``FusedAttnBackend["FP8"]``) or in the input precision.
    ``backward`` uses either flash-attn's backward kernel (``use_FAv2_bwd``) or
    ``fused_attn_bwd_qkvpacked``, again in FP8 or in the input precision.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen, cu_seqlens,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                qkv, qkv_dtype, attn_bias, attn_scale,
                dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd,
                fp8, fp8_meta):
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                # FP8 MHA: qkv is already quantized; reuse its scale_inv.
                assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA."
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split('_'))
            assert (qkv_group == 1
                ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, \
                but found {qkv_layout}."
            if fp8_meta["recipe"].fp8_mha:
                qkv_fp8 = qkv._data
            else:
                # Quantize through a 2-D view (last three dims flattened), then restore shape.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(qkv_c,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training, max_seqlen, cu_seqlens,
                qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale, dropout_p, fast_zero_fill, qkv_layout,
                attn_bias_type, attn_mask_type, rng_gen)
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            out_save = out_ret
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # Backward will run in high precision: save dequantized qkv and out.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv = cast_from_fp8(qkv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            # Snapshot the forward scales so backward sees the values used here.
            fp8_tensors = (qkv_fp8, out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            logger.debug("Running forward in %s",qkv.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
                fused_attention_backend, attn_bias,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                None, None, None, None, None, None,
                attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen)
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        # Backward runs in FP8 only when forward did and NVTE_FP8_DPA_BWD is not "0".
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(*qkvo_tensors, cu_seqlens,
            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
            *fp8_tensors, *aux_ctx_tensors)
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = \
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            # Keep the Float8Tensor wrapper (scale_inv, dtype); operate on raw data.
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (qkv, out, cu_seqlens,
            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
            qkv_fp8, out_fp8,
            fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dqkv = torch.empty_like(qkv)

            def _maybe_contiguous(x):
                return x.contiguous() if x.stride(-1) != 1 else x

            # NOTE(review): q/k/v are taken from dim 1 of qkv, which presumes a
            # layout with the packed "3" at dim 1 (e.g. t3hd) — confirm at call site.
            d_out, q, k, v, out = [_maybe_contiguous(x)
                for x in (d_out, qkv[:,0], qkv[:,1], qkv[:,2], out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dqkv[:,0], dqkv[:,1], dqkv[:,2],
                cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            dqkv = dqkv[..., :d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                            ).view(d_out.shape)
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen, cu_seqlens,
                        qkv_fp8, out_fp8, d_out_fp8,
                        fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                        fwd_scale_invs[META_QKV], # d_scale_qkv,
                        fwd_scale_invs[META_S], # d_scale_s,
                        fwd_scale_invs[META_O], # d_scale_o,
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
                        fwd_scales[META_S], # q_scale_s
                        ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
                        ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dqkv = Float8Tensor(data=dqkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                    else:
                        dqkv_c_fp8 = dqkv_fp8.view(-1,
                            dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1])
                        dqkv = cast_from_fp8(dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"], META_DQKV,
                            fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s",qkv.dtype)
                    if d_out.dtype == torch.uint8:
                        # FP8 MHA output but high-precision backward: dequantize d_out.
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
                        ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                        None, None, None, None, None, None, None, None, None, None,
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # if no_bias or alibi, there is no bias gradient; otherwise it is rest[0]
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        # One entry per forward() input (21 in total): qkv is input index 7 and
        # attn_bias is input index 9; every other input gets no gradient.
        return (None,) * 7 + (dqkv, None, dbias) + (None,) * 11

2503

2504
2505
2506
2507
2508
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    Autograd wrapper around ``fused_attn_fwd_kvpacked`` / ``fused_attn_bwd_kvpacked``
    where ``q`` is a separate tensor and ``k``/``v`` are packed together in ``kv``.
    Both a high-precision path and an FP8 path (selected by ``fp8``) are supported;
    the backward pass can alternatively be dispatched to ``flash_attn_cuda_bwd``
    when ``use_FAv2_bwd`` is set.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
                use_FAv2_bwd, fp8, fp8_meta):
        """Run the fused attention forward pass for separate ``q`` and packed ``kv``.

        When ``fp8`` is true:
          * with ``fp8_meta["recipe"].fp8_mha``, ``q``/``kv`` must be ``Float8Tensor``
            and their raw ``_data`` is used directly;
          * otherwise ``q``/``kv`` are cast to FP8 using the ``META_QKV`` scaling slot
            (only the kv-packed layout group is accepted).
          The FP8 backend is forced, and the output is returned as a ``Float8Tensor``
          (fp8_mha) or cast back to ``qkv_dtype``.

        When ``fp8`` is false, the high-precision kernel runs with the caller's
        ``fused_attention_backend``.

        Returns:
            The attention output.
        """
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if fp8:
            logger.debug("Running forward in FP8")
            if fp8_meta["recipe"].fp8_mha:
                assert (isinstance(q, Float8Tensor)
                    and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA."
                # Reuse the incoming tensors' scale_inv so the kernel de-scales q/kv
                # with the same factor they were quantized with.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split('_'))
                assert (qkv_group == 2
                    ), ("qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                        f"but found {qkv_layout}.")
                q_fp8 = cast_to_fp8(q,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward).view(q.shape)
                # Flatten the trailing (2, h, d)-style dims into one for the cast,
                # then restore the original shape.
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(kv_c,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward).view(kv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale, dropout_p, fast_zero_fill, qkv_layout,
                attn_bias_type, attn_mask_type, rng_gen)
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            out_save = out_ret
            # With NVTE_FP8_DPA_BWD=0 the backward runs in high precision, so keep
            # de-quantized copies of q/kv/out for it instead of the FP8 tensors.
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                q = cast_from_fp8(q._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv = cast_from_fp8(kv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            # Snapshot the forward scales (presumably so later updates to fp8_meta
            # cannot change the values backward de-scales with).
            fp8_tensors = (q_fp8, kv_fp8, out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            logger.debug("Running forward in %s",q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, qkv_dtype, fused_attention_backend, attn_bias,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                None, None, None, None, None, None,
                attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen)
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None)

        # Backward runs in FP8 only when forward did and NVTE_FP8_DPA_BWD is non-zero.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv,
            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
            *fp8_tensors, *aux_ctx_tensors)
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        # A high-precision backward always uses the arbitrary-seqlen F16 backend.
        ctx.fused_attention_backend = \
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients of the kv-packed fused attention.

        Returns a tuple aligned with ``forward``'s inputs. Only ``q``, ``kv`` and,
        for bias types other than ``no_bias``/``alibi``, ``attn_bias`` receive
        gradients; every other position is ``None``.
        """
        logger = logging.getLogger("FusedAttnFunc_kvpacked")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (q, kv, out, cu_seqlens_q, cu_seqlens_kv,
            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
            q_fp8, kv_fp8, out_fp8,
            fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, q, kv[:,0], kv[:,1], out)]
            # dkv[:,0] / dkv[:,1] are views, so the kernel writes dk/dv into dkv.
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dkv[:,0], dkv[:,1],
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            dq = dq[..., :d_out.shape[-1]]
            dkv = dkv[..., :d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                            ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q_fp8, kv_fp8, out_fp8, d_out_fp8,
                        fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                        fwd_scale_invs[META_QKV], # d_scale_qkv,
                        fwd_scale_invs[META_S], # d_scale_s,
                        fwd_scale_invs[META_O], # d_scale_o,
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
                        fwd_scales[META_S], # q_scale_s
                        ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
                        ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dq = Float8Tensor(data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                        dkv = Float8Tensor(data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"], META_DQKV,
                            fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(-1,
                            dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1])
                        dkv = cast_from_fp8(dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"], META_DQKV,
                            fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
                else:
                    logger.debug("Running backward in %s",q.dtype)
                    # An FP8 (uint8 payload) d_out reaching the high-precision path is
                    # de-quantized to q's dtype first.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q, kv, out, d_out,
                        ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                        None, None, None, None, None, None, None, None, None, None,
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # Gradient positions follow forward's inputs: dq -> q (index 9),
        # dkv -> kv (index 10), dbias -> attn_bias (index 12).
        # NOTE(review): the tuple is longer than forward's input list; autograd
        # accepts extra trailing Nones — keep them None if the signature changes.
        # if no_bias or alibi, return (dq, dkv)
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (None, None, None, None, None, None, None, None, None, dq, dkv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, return (dq, dkv, dbias)
        return (None, None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)

2725
2726
2727
2728
2729
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors.

    Autograd wrapper around ``fused_attn_fwd`` / ``fused_attn_bwd``. Supports a
    high-precision path and an FP8 path (selected by ``fp8``). Internally, the FP8
    cast is done per layout group (qkv packed, kv packed, or fully separate, as
    derived from ``qkv_layout``). The backward pass can alternatively be
    dispatched to ``flash_attn_cuda_bwd`` when ``use_FAv2_bwd`` is set.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
                use_FAv2_bwd, fp8, fp8_meta):
        """Run the fused attention forward pass for separate ``q``, ``k``, ``v``.

        When ``fp8`` is true:
          * with ``fp8_meta["recipe"].fp8_mha``, ``q``/``k``/``v`` must be
            ``Float8Tensor`` and their raw ``_data`` is used directly;
          * otherwise they are cast to FP8 with the ``META_QKV`` scaling slot. The
            cast is done on the combined tensor of the layout group (number of
            ``_``-separated parts of ``qkv_layout``: 1 = qkv packed,
            2 = kv packed, 3 = separate).
          The FP8 backend is forced, and the output is returned as a ``Float8Tensor``
          (fp8_mha) or cast back to ``qkv_dtype``.

        Returns:
            The attention output.
        """
        logger = logging.getLogger("FusedAttnFunc")
        if fp8:
            logger.debug("Running forward in FP8")
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                assert (isinstance(q, Float8Tensor)
                    and isinstance(k, Float8Tensor)
                    and isinstance(v, Float8Tensor)), "q/k/v must be Float8Tensors for FP8 MHA."
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split('_'))
                if qkv_group == 1:
                    # Cast the combined qkv tensor once, then split it back into views.
                    dim = qkv_layout.find('3')
                    qkv = _combine_tensors([q,k,v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_fp8 = cast_to_fp8(qkv_c,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(qkv.shape)
                    q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1,1,1])
                    q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
                if qkv_group == 2:
                    q_fp8 = cast_to_fp8(q,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(q.shape)
                    dim = qkv_layout.split('_')[1].find('2')
                    kv = _combine_tensors([k,v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_fp8 = cast_to_fp8(kv_c,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(kv.shape)
                    k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1,1])
                    k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
                if qkv_group == 3:
                    q_fp8 = cast_to_fp8(q,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(q.shape)
                    k_fp8 = cast_to_fp8(k,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(k.shape)
                    v_fp8 = cast_to_fp8(v,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(v.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale, dropout_p, fast_zero_fill, qkv_layout,
                attn_bias_type, attn_mask_type, rng_gen)
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            out_save = out_ret

            # With NVTE_FP8_DPA_BWD=0 the backward runs in high precision, so keep
            # de-quantized copies of q/k/v/out for it instead of the FP8 tensors.
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split('_'))
                if qkv_group == 1:
                    dim = qkv_layout.find('3')
                    qkv = _combine_tensors([q,k,v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_no_fp8 = cast_from_fp8(qkv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape)
                    q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1,1,1])
                    q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                if qkv_group == 2:
                    q = cast_from_fp8(q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
                    dim = qkv_layout.split('_')[1].find('2')
                    kv = _combine_tensors([k,v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_no_fp8 = cast_from_fp8(kv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape)
                    k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1,1])
                    k, v = [x.squeeze(dim) for x in [k, v]]
                if qkv_group == 3:
                    q = cast_from_fp8(q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
                    k = cast_from_fp8(k._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[k.dtype]).view(k.shape)
                    v = cast_from_fp8(v._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[v.dtype]).view(v.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)

            # Snapshot the forward scales (presumably so later updates to fp8_meta
            # cannot change the values backward de-scales with).
            fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            logger.debug("Running forward in %s",q.dtype)
            out_ret, aux_ctx_tensors = fused_attn_fwd(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
                seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                None, None, None, None, None, None,
                attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen)
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None, None)

        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
            # With activation offloading, the layout recorded for backward is forced
            # to the fully separate sbhd form.
            qkv_layout = 'sbhd_sbhd_sbhd'
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        # Backward runs in FP8 only when forward did and NVTE_FP8_DPA_BWD is non-zero.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
        ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv,
            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
            *fp8_tensors, *aux_ctx_tensors)
        ctx.fp8_meta = fp8_meta
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        # A high-precision backward always uses the arbitrary-seqlen F16 backend.
        ctx.fused_attention_backend = \
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients of the separate-QKV fused attention.

        Returns a tuple aligned with ``forward``'s inputs. Only ``q``, ``k``, ``v``
        and, for bias types other than ``no_bias``/``alibi``, ``attn_bias`` receive
        gradients; every other position is ``None``.
        """
        logger = logging.getLogger("FusedAttnFunc")
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (q, k, v, out, cu_seqlens_q, cu_seqlens_kv,
            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
            q_fp8, k_fp8, v_fp8, out_fp8,
            fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors
        if not aux_ctx_tensors[0].is_contiguous():
            aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dk, dv,
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            dq = dq[..., :d_out.shape[-1]]
            dk = dk[..., :d_out.shape[-1]]
            dv = dv[..., :d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    logger.debug("Running backward in FP8")
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                            ).view(d_out.shape)
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8,
                        fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                        fwd_scale_invs[META_QKV], # d_scale_qkv,
                        fwd_scale_invs[META_S], # d_scale_s,
                        fwd_scale_invs[META_O], # d_scale_o,
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
                        fwd_scales[META_S], # q_scale_s
                        ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
                        ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dq = Float8Tensor(data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                        dk = Float8Tensor(data=dk_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                        dv = Float8Tensor(data=dv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                    else:
                        # De-quantize per layout group, mirroring the forward cast:
                        # 1: qkv packed, 2: kv packed, 3: qkv separate
                        qkv_group = len(ctx.qkv_layout.split('_'))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find('3')
                            dqkv_fp8 = _combine_tensors([dq_fp8,dk_fp8,dv_fp8], dim)
                            dqkv_c_fp8 = dqkv_fp8.view(-1,
                                dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1])
                            dqkv = cast_from_fp8(dqkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1,1,1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        if qkv_group == 2:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
                            dim = ctx.qkv_layout.split('_')[1].find('2')
                            dkv_fp8 = _combine_tensors([dk_fp8,dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(-1,
                                dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1])
                            dkv = cast_from_fp8(dkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1,1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        if qkv_group == 3:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
                            dk = cast_from_fp8(
                                dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dk_fp8.shape)
                            dv = cast_from_fp8(
                                dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape)
                else:
                    logger.debug("Running backward in %s",q.dtype)
                    # An FP8 (uint8 payload) d_out reaching the high-precision path is
                    # de-quantized to q's dtype first.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q, k, v, out, d_out,
                        ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                        None, None, None, None, None, None, None, None, None, None,
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # Gradient positions follow forward's inputs: dq/dk/dv -> q/k/v
        # (indices 9-11), dbias -> attn_bias (index 13).
        # NOTE(review): the tuple is longer than forward's input list; autograd
        # accepts extra trailing Nones — keep them None if the signature changes.
        # if no_bias or alibi, return (dq, dk, dv)
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (None, None, None, None, None, None,
                    None, None, None, dq, dk, dv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, return (dq, dk, dv, dbias)
        return (None, None, None, None, None, None,
                None, None, None, dq, dk, dv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)

3042

3043
class FusedAttention(TransformerEngineBaseModule):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |

    Note: `forward` additionally accepts `torch.uint8` (FP8) inputs and logs the FP8
    recipe settings when `fused_attention_backend` is `NVTE_FP8`; the matrix above
    only describes the two F16 backends.
    """

    def __init__(
        self,
        softmax_scale: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Store the attention configuration and set up process-wide cuDNN knobs.

        Parameters
        ----------
        softmax_scale : float
            Scale applied to the attention scores before softmax.
        attention_dropout : float, default = 0.0
            Dropout probability (only applied when the module is in training mode).
        attention_dropout_ctx : Callable, default = `nullcontext`
            Zero-argument callable returning the context manager that wraps the
            attention computation (e.g. an RNG-tracker fork).
        attention_type : str, default = "self"
            "self" or "cross"; decides how `attention_mask` is turned into cu_seqlens.
        layer_number : Optional[int], default = None
            Layer index; defaults to 1.
        deterministic : bool, default = False
            When True, forces the cuDNN workspace-optimization path by setting
            ``CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=-1`` in ``os.environ``
            (a process-wide side effect).
        """
        super().__init__()

        self.logger = logging.getLogger("FusedAttention")
        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # Opt-in: use flash-attn's backward kernel, only honoured on sm90 devices.
        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
                        and get_device_compute_capability() == (9, 0))
        self.layer_number = 1 if layer_number is None else layer_number
        if deterministic:
            # workspace optimization path is deterministic
            os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
        # - unset:       enables workspace optimization when required workspace is <= 256MB
        #                or when bias gradient needs to be computed
        # - n:           enables workspace optimization when required workspace is <= n bytes
        # - -1:          enables workspace optimization always
        # - 0:           disables workspace optimization always
        # NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT is evaluated after `deterministic`, so an
        # explicit user setting overrides it.
        if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        def remove_extra_states_check(module, incompatible_keys): # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            when loading older TransformerEngine checkpoints. Will phase out
            this hook in TransformerEngine 2.0.
            """
            # Iterate over a snapshot: removing from the list being iterated would
            # skip the element that follows each removed key.
            for key in list(incompatible_keys.missing_keys):
                if 'fused_attention._extra_state' in key:
                    incompatible_keys.missing_keys.remove(key)
        self.register_load_state_dict_post_hook(remove_extra_states_check)

    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[Float8Tensor]:
        """Needs override.

        No weights are registered in ``__init__``, so there is nothing to provide;
        this implementation has no ``return`` statement and therefore returns None.
        """

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        seq_offsets_q: Optional[torch.Tensor] = None,
        seq_offsets_k: Optional[torch.Tensor] = None,
        seq_offsets_v: Optional[torch.Tensor] = None,
        seq_offsets_o: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        fused_attention_backend:
            tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        is_first_microbatch: Optional[bool] = None,
    ) -> torch.Tensor:
        """fused attention fprop

        Validates the inputs, fills in any sequence metadata the caller did not
        provide, then dispatches to ``attn_forward_func_with_cp`` (context
        parallelism) or ``FusedAttnFunc``.

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
            CUDA tensors of dtype float16, bfloat16 or uint8.
        qkv_layout : str, default = "sbh3d"
            Must be in ``QKVLayouts``. The alphabetic characters of its first
            ``_``-separated segment give ``qkv_format`` (e.g. "sbh3d" -> "sbhd").
        cu_seqlens_q, cu_seqlens_kv : Optional[torch.Tensor]
            Cumulative sequence lengths. For "sbhd"/"bshd" they are derived from
            ``attention_mask`` (padding masks) or built for full-length sequences
            when None; for "thd" they are required.
        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o : Optional[torch.Tensor]
            Only used for "thd". If any is None, or context parallelism is active,
            they are recomputed from ``cu_seqlens_*`` and the head sizes.
        max_seqlen_q, max_seqlen_kv : Optional[int]
            Required for "thd"; for "sbhd"/"bshd" they are overwritten with the
            sequence dimension of the query/key tensors.
        attn_mask_type : str, default = "causal"
            Mask type; any type containing "padding" triggers cu_seqlens derivation.
        attention_mask : Optional tensor, or a (q, kv) tuple for cross attention
            Only used here to derive ``cu_seqlens_*`` for padding masks.
        fused_attention_backend : tex.NVTE_Fused_Attn_Backend
            Must not be ``NVTE_No_Backend``; context parallelism additionally
            requires ``NVTE_F16_arbitrary_seqlen``.
        core_attention_bias_type, core_attention_bias, fast_zero_fill
            Forwarded to the backend; "alibi" is rejected with context parallelism.
        cp_group, cp_global_ranks, cp_stream
            Context-parallel settings; CP is active when ``cp_group`` is given and
            its world size is not 1.
        is_first_microbatch : Optional[bool]
            Forwarded to ``prepare_forward`` on the non-CP path.

        Returns
        -------
        torch.Tensor
            Attention output with its last two dimensions (heads, head_dim)
            flattened into one.

        Raises
        ------
        AssertionError
            On an unsupported backend, dtype, device, layout or CP combination.
        RuntimeError
            When a padding mask is requested without ``attention_mask`` or
            ``cu_seqlens_q``/``cu_seqlens_kv``.
        """
        assert (fused_attention_backend
            != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
            ), 'No fused attention backend supports this input combination!'
        assert (
            (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            ), 'FusedAttention only supports FP16, BF16 and FP8 (torch.uint8) data types.'
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'FusedAttention only supports CUDA tensors.'
        assert (
            qkv_layout in QKVLayouts
            ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # e.g. "sb3hd" -> "sbhd", "bshd_bs2hd" -> "bshd", "thd_thd_thd" -> "thd"
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])

        if qkv_format in ['sbhd', 'bshd']:
            # Padded formats: batch size and max sequence lengths come from the shapes.
            if qkv_format == 'sbhd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1], query_layer.shape[0], key_layer.shape[0])
            elif qkv_format == 'bshd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
            if 'padding' in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        # Cross attention: a (q_mask, kv_mask) pair is expected.
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                # No padding: every sequence has the full (max) length.
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        if qkv_format == 'thd':
            assert (max_seqlen_q is not None
                and max_seqlen_kv is not None
                and cu_seqlens_q is not None
                and cu_seqlens_kv is not None
                ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
            if (seq_offsets_q is None
                or seq_offsets_k is None
                or seq_offsets_v is None
                or seq_offsets_o is None
                or context_parallel):
                # Strip the b/s/t dims from the layout to get the packing group,
                # e.g. "t3hd" -> "3hd", "thd_th2d" -> "hd_h2d".
                qkv_group = ''.join([x for x in qkv_layout if x not in 'bst'])
                # Under context parallelism, offsets are computed as if q, k and v
                # were separate (unpacked) tensors.
                qkv_group = 'hd_hd_hd' if context_parallel else qkv_group
                num_heads = query_layer.shape[-2]
                num_gqa_groups = key_layer.shape[-2]
                head_dim = query_layer.shape[-1]
                # Presumably element offsets of each sequence start in the (packed)
                # buffers: tokens * per-token stride -- TODO confirm with the kernel.
                seq_offsets_o = num_heads * head_dim * cu_seqlens_q
                if qkv_group == 'hd_hd_hd':
                    seq_offsets_q = num_heads * head_dim * cu_seqlens_q
                    seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
                    seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
                elif qkv_group in ['3hd', 'h3d']:
                    seq_offsets_q = num_heads * head_dim * 3 * cu_seqlens_q
                    seq_offsets_k = num_heads * head_dim * 3 * cu_seqlens_q
                    seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q
                elif qkv_group in ['hd_2hd', 'hd_h2d']:
                    seq_offsets_q = num_heads * head_dim * cu_seqlens_q
                    seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
                    seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # flash-attn's backward is only substituted for the bias-free arbitrary-seqlen backend.
        use_FAv2_bwd = (self.use_FAv2_bwd
                and (core_attention_bias_type == "no_bias")
                and (fused_attention_backend
                    == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))

        if context_parallel:
            assert (fused_attention_backend
                == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
                ), f"{fused_attention_backend} does not work with context parallelism!"
            assert (
                core_attention_bias_type not in ["alibi"]
            ), f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [x.contiguous()
                for x in (query_layer, key_layer, value_layer)]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv,
                    max_seqlen_q, max_seqlen_kv,
                    seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=self.softmax_scale,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    use_fused_attention=True,
                )
        else:
            with self.prepare_forward(query_layer,
                is_first_microbatch,
                num_gemms=3,
                allow_non_contiguous=True) as query_layer:
                with self.attention_dropout_ctx():
                    # fp8_mha implies fp8_dpa: turn it on (mutating the recipe object
                    # stored in fp8_meta) and note that it was forced for the log.
                    forced_fp8_dpa = ""
                    if self.fp8_meta["recipe"].fp8_mha:
                        if not self.fp8_meta["recipe"].fp8_dpa:
                            self.fp8_meta["recipe"].fp8_dpa = True
                            forced_fp8_dpa = " (forced)"
                    if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8:
                        self.logger.debug(
                            "Running with fp8_recipe.fp8_mha=%s, "
                            "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
                            self.fp8_meta["recipe"].fp8_mha,
                            self.fp8_meta["recipe"].fp8_dpa,
                            forced_fp8_dpa,
                            int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
                    output = FusedAttnFunc.apply(
                        self.training,
                        max_seqlen_q, max_seqlen_kv,
                        cu_seqlens_q, cu_seqlens_kv,
                        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                        query_layer, key_layer, value_layer,
                        qkv_dtype,
                        core_attention_bias,
                        self.softmax_scale,
                        self.attention_dropout if self.training else 0.0,
                        fast_zero_fill,
                        qkv_layout,
                        core_attention_bias_type,
                        attn_mask_type,
                        None, # rng_gen
                        fused_attention_backend,
                        use_FAv2_bwd,
                        self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        self.fp8_meta,
                    )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
3308
3309


3310
3311
3312
3313
3314
3315
3316
class DotProductAttention(torch.nn.Module):
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
3319
3320
3321

    .. warning::

        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
3326
3327
3328
3329
3330
3331

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels : int
                number of key-query-value channels per attention head.
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
3341
3342
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
3343
    attn_mask_type: str, default = `causal`
3344
3345
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356
3357
                   type of attention mask passed into softmax operation, options are "`no_mask`",
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`", and
                   "`arbitrary`", where "`padding,causal`" and "`causal,padding`" are equivalent.
                   This arg can be overridden by :attr:`attn_mask_type` in the `forward` method.
                   It is useful for cases involving compilation/tracing, e.g. ONNX export, and the
                   forward arg is useful for dynamically changing mask types, e.g. a different mask
                   for training and inference. For "`no_mask`", no attention mask is applied. For
                   "`causal`" or the causal mask in "`padding,causal`", TransformerEngine calculates
                   and applies an upper triangular mask to the softmax input. No user input is
                   needed. For "`padding`" or the padding mask in "`padding,causal`", users need to
                   provide the locations of padded tokens either via :attr:`cu_seqlens_q` and
                   :attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
                   in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
                   need to provide a mask that is broadcastable to the shape of softmax input.
3358
3359
3360
3361
3362
3363
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
3364
3365
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
3366
3367
3368
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
3369
3370
3371
3372
3373
3374
3375
3376
3377
3378
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
               `h` the number of heads, `d` head size, and `t` the total number of sequences
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
               For that, please use `_get_qkv_layout` to gain the layout information.
3379
3380
3381
    softmax_scale: Optional[float], default = `None`
                softmax scale for the attention scores. If `None`, defaults to
                `1.0 / math.sqrt(kv_channels)`.
3382
3383
3384
3385
3386
3387
3388
3389
3390

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
3391
3392
3393
3394
3395
3396
3397
3398
3399
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
3400
3401
3402
3403
3404
3405
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        softmax_scale: Optional[float] = None,
    ) -> None:
        """Configure the module and instantiate its attention backends.

        Parameters are documented in the class docstring. In addition:

        - ``attn_mask_type`` is normalized: "," becomes "_" and "causal_padding"
          becomes "padding_causal".
        - The tensor-parallel size is taken from ``tp_group`` when one is given,
          otherwise from ``tp_size``; the per-partition GQA group count uses that
          resolved size.
        - ``softmax_scale`` defaults to ``1 / sqrt(kv_channels)``.
        - FlashAttention and FusedAttention are only instantiated when enabled by
          ``NVTE_FLASH_ATTN`` / ``NVTE_FUSED_ATTN`` (default "1") and the device
          compute capability is at least sm80; UnfusedDotProductAttention is
          always instantiated.

        Raises
        ------
        AssertionError
            If ``num_attention_heads`` is not divisible by the number of GQA
            groups, or ``attention_type`` is not in ``AttnTypes``.
        """
        super().__init__()

        self.logger = logging.getLogger("DotProductAttention")
        self.qkv_format = qkv_format
        # Canonicalize mask names: "padding,causal" / "causal,padding" -> "padding_causal".
        attn_mask_type = attn_mask_type.replace(",","_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

        self.hidden_size_per_attention_head = kv_channels

        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        # Divide by the *resolved* TP world size: when tp_group is passed, the tp_size
        # argument (default 1) does not reflect the actual partitioning.
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (num_attention_heads % self.num_gqa_groups == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"

        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(kv_channels)

        self.device_compute_capability = get_device_compute_capability()
        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) \
                             or torch.are_deterministic_algorithms_enabled()

        # Read each backend switch once; 0 disables the backend.
        flash_attn_env = int(os.getenv("NVTE_FLASH_ATTN", "1"))
        fused_attn_env = int(os.getenv("NVTE_FUSED_ATTN", "1"))

        self.use_flash_attention = (
            flash_attn_env
            and self.device_compute_capability >= (8, 0)
        )
        if flash_attn_env == 0:
            self.logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
        if self.device_compute_capability < (8, 0):
            self.logger.debug("Disabling FlashAttention for compute capability < sm80")

        if not _flash_attn_2_4_1_plus and self.deterministic:
            self.use_flash_attention = False
            self.logger.warning(
                "Disabling usage of FlashAttention since version <2.4.1 does not support "
                "deterministic execution. In order to use FA with deterministic behavior,"
                " please install FlashAttention version >=2.4.1."
            )

        self.use_fused_attention = (
            fused_attn_env
            and self.device_compute_capability >= (8, 0)
        )
        if fused_attn_env == 0:
            self.logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
        if self.device_compute_capability < (8, 0):
            self.logger.debug("Disabling FusedAttention for compute capability < sm80")

        assert (
            attention_type in AttnTypes
        ), f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        if self.use_flash_attention:
            self.flash_attention = FlashAttention(softmax_scale,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        if self.use_fused_attention:
            self.fused_attention = FusedAttention(softmax_scale,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        self.unfused_attention = UnfusedDotProductAttention(
            softmax_scale, **attn_kwargs, layer_number=layer_number)
3522
3523
3524
3525
3526

    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Forward method with activation checkpointing.

        Runs ``attention_func(*forward_args, **forward_kwargs)`` through
        ``checkpoint``, using this module's RNG-state tracker and tensor-parallel
        group, with ``distribute_saved_activations=False``.

        Returns
        -------
        torch.Tensor
            Whatever ``checkpoint`` returns for the wrapped call.
        """

        def custom_forward(*input_args, **input_kwargs):
            return attention_func(*input_args, **input_kwargs)

        # Positional args are listed right after the function they belong to; the
        # call is equivalent to unpacking them after the keyword arguments.
        hidden_states = checkpoint(
            custom_forward,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

        return hidden_states

3545
3546
3547
3548
3549
3550
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        The three values are stored as-is on the module (``cp_group``,
        ``cp_global_ranks``, ``cp_stream``), replacing whatever was passed
        to the constructor.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

3568
    @no_torch_dynamo(recursive=False)
3569
3570
3571
3572
3573
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
3574
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
3575
3576
3577
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
3578
3579
3580
3581
        seq_offsets_q: Optional[torch.Tensor] = None,
        seq_offsets_k: Optional[torch.Tensor] = None,
        seq_offsets_v: Optional[torch.Tensor] = None,
        seq_offsets_o: Optional[torch.Tensor] = None,
3582
3583
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
3584
        attn_mask_type: Optional[str] = None,
3585
        window_size: Optional[Tuple[int, int]] = None,
3586
        checkpoint_core_attention: bool = False,
3587
3588
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
3589
        alibi_slopes: Optional[torch.Tensor] = None,
3590
        fast_zero_fill: bool = True,
3591
        inference_params: Optional[InferenceParams] = None,
3592
        is_first_microbatch: Optional[bool] = None,
3593
3594
3595
3596
3597
3598
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

3599
3600
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes '"padding"' or `"arbitrary"`.
3601
3602
3603

        .. note::

3604
3605
3606
            Input tensor :attr:`query_layer` must be of shape
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
            :attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer`
3607
            must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
3608
            :attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape
3609
3610
3611
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
            * :attr:`kv_channels`) is returned.

3612
3613
        .. note::

3614
3615
3616
3617
3618
3619
3620
3621
3622
3623
3624
3625
3626
3627
3628
3629
3630
3631
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://arxiv.org/pdf/2305.13245.pdf>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
3632
3633
3634
3635
3636
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
3637

3638
3639
3640
3641
3642
3643
3644
3645
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
3646
3647
3648
3649
3650
3651
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
3652
3653
3654
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value
             means the corresponding position is masked out and a `False` means that position is
             allowed to participate in attention.
3655
3656
3657
3658
3659
3660
3661
3662
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
3663
3664
3665
3666
3667
3668
3669
3670
3671
3672
3673
3674
        seq_offsets_q: Optional[torch.Tensor], default = `None`
                   Cumulative offset of different sequences in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
        seq_offsets_k: Optional[torch.Tensor], default = `None`
                   Cumulative offset of different sequences in a batch for `key_layer`,
                   with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
        seq_offsets_v: Optional[torch.Tensor], default = `None`
                   Cumulative offset of different sequences in a batch for `value_layer`,
                   with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
        seq_offsets_o: Optional[torch.Tensor], default = `None`
                   Cumulative offset of different sequences in a batch for forward output,
                   with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
3675
3676
3677
3678
3679
3680
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
3681
3682
3683
        attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
                       `arbitrary`}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
3684
        window_size: Optional[Tuple[int, int]], default = `None`
3685
                    Sliding window size for local attention.
3686
3687
3688
3689
3690
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
3691
        core_attention_bias_type: str, default = `no_bias`
3692
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
3693
        core_attention_bias: Optional[torch.Tensor], default = `None`
3694
3695
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
3696
3697
3698
3699
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
3700
        fast_zero_fill: bool, default = `True`
3701
                    Whether to use the fast path to set output tensors to 0 or not.
3702
3703
3704
3705
3706
3707
3708
3709
3710
3711
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
3712
3713
3714
3715
3716
3717
3718
3719
3720
3721
3722
3723
3724
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
3725
3726
        """

3727
3728
3729
3730
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'DotProductAttention only supports CUDA tensors.'

3731
3732
3733
        assert (key_layer.shape == value_layer.shape
            ), "Keys and values must have the same shape!"

3734
3735
        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)
3736
        if attn_mask_type is None:
3737
            attn_mask_type = self.attn_mask_type
3738
3739
3740
3741
3742
3743
3744
        else:
            attn_mask_type = attn_mask_type.replace(",","_")
            if attn_mask_type == "causal_padding":
                attn_mask_type = "padding_causal"

        assert (attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
3745
3746
3747
        if qkv_format == 'thd':
            assert ('padding' in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
3748

3749
3750
3751
3752
3753
3754
3755
3756
        if self.rng_states_tracker is not None and is_graph_capturing():
            assert (
                isinstance(self.rng_states_tracker, CudaRNGStatesTracker)
            ), "Unsupported RNG states tracker."
            assert (
                graph_safe_rng_available()
            ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."

3757
3758
3759
        if window_size is None:
            window_size = self.window_size

3760
3761
        if qkv_format is None:
            qkv_format = self.qkv_format
3762

3763
3764
3765
3766
3767
3768
3769
3770
3771
3772
3773
3774
3775
3776
3777
3778
3779
3780
3781
3782
3783
3784
3785
3786
3787
3788
3789
3790
3791
3792
3793
3794
3795
        if inference_params is not None:
            assert self.layer_number is not None, "Layer number must be set!"

            if qkv_format == "bshd":
                key_layer = key_layer.transpose(0, 1)
                value_layer = value_layer.transpose(0, 1)

            (inference_key_memory, inference_value_memory,
            ) = inference_params.key_value_memory_dict[self.layer_number]

            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)

            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)

            # Copy keys and values into KV-cache
            inference_key_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
            inference_value_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

            if qkv_format == "bshd":
                key_layer = key_layer.transpose(0, 1)
                value_layer = value_layer.transpose(0, 1)

            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

3796
        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
3797
3798
3799
3800
3801
3802
3803
3804
3805
3806
3807
3808
3809
3810
3811
3812
3813
            and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
        assert (qkv_format in ['sbhd', 'bshd', 'thd']
            ), "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

        if qkv_format == 'thd':
            assert (all(len(x.shape) == 3 for x in (query_layer, key_layer, value_layer))
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
                and len(cu_seqlens_q.shape) == 1
                and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
            assert (cu_seqlens_q.dtype == torch.int32
                and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
3814
3815
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
3816
                max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
3817
3818
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
3819
                max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
3820
3821
3822
3823
3824
3825
3826
3827
3828
3829
3830
3831
3832
3833
3834
3835
3836
3837
3838

        if qkv_format in ['sbhd', 'bshd']:
            assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
            if qkv_format == 'sbhd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
            if qkv_format == 'bshd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
            if cu_seqlens_q is not None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                assert (all(seqlens_q <= max_seqlen_q)
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                    the sequence dimention in 'query_layer'!"""
            if cu_seqlens_kv is not None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                assert (all(seqlens_kv <= max_seqlen_kv)
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                    the sequence dimention in 'key_layer' and 'value_layer'!"""

3839
3840
3841
3842
3843
3844
3845
3846
        if (isinstance(query_layer, Float8Tensor)
            and isinstance(key_layer, Float8Tensor)
            and isinstance(value_layer, Float8Tensor)):
            qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout(
                query_layer._data, key_layer._data, value_layer._data, qkv_format = qkv_format)
        else:
            qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
                query_layer, key_layer, value_layer, qkv_format = qkv_format)
3847

3848
3849
        # The priority for attention backends (subject to availability and clearing the filters)
        # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
3850
        use_flash_attention = self.use_flash_attention
3851
        use_fused_attention = self.use_fused_attention
3852
        use_unfused_attention = True
3853

3854
3855
3856
        # The following section filters out some backends based on
        # certain asserts before executing the forward pass.

3857
        # Filter: QKV layout.
3858
3859
        if use_unfused_attention and qkv_format == 'thd':
            self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
3860
3861
            use_unfused_attention = False

3862
3863
        # Filter: ONNX export.
        if is_in_onnx_export_mode():
3864
3865
            if use_flash_attention:
                self.logger.debug("Disabling FlashAttention for ONNX mode")
3866
            use_flash_attention = False
3867
3868
            if use_fused_attention:
                self.logger.debug("Disabling FusedAttention for ONNX mode")
3869
3870
            use_fused_attention = False

3871
        # Filter: Input type.
3872
3873
3874
3875
3876
        if (use_flash_attention
            and (query_layer.dtype not in [torch.bfloat16, torch.float16]
                or key_layer.dtype not in [torch.bfloat16, torch.float16]
                or value_layer.dtype not in [torch.bfloat16, torch.float16]
                or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]))
3877
        ):
3878
3879
3880
3881
3882
            self.logger.debug(
                "Disabling FlashAttention due to unsupported QKV data types. "
                "Supported: [torch.bfloat16, torch.float16]. "
                "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
                query_layer.dtype, key_layer.dtype, value_layer.dtype)
3883
            use_flash_attention = False
3884
3885
3886
3887
        if (use_fused_attention
            and (query_layer.dtype not in [torch.bfloat16, torch.float16]
                or key_layer.dtype not in [torch.bfloat16, torch.float16]
                or value_layer.dtype not in [torch.bfloat16, torch.float16])
3888
        ):
3889
3890
3891
3892
3893
            self.logger.debug(
                "Disabling FusedAttention due to unsupported QKV data types. "
                "Supported: [torch.bfloat16, torch.float16, Float8Tensor]. "
                "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
                query_layer.dtype, key_layer.dtype, value_layer.dtype)
3894
            use_fused_attention = False
3895

3896
        # Filter: Device and dimensions.
3897
        # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
3898
        # FAv2 requires head_dim % 8 == 0
3899
3900
3901
3902
3903
3904
3905
3906
3907
3908
3909
        if (use_flash_attention
            and (query_layer.shape[-1] > 256
                or query_layer.shape[-1] % 8 != 0
                or (query_layer.shape[-1] > 192
                    and self.device_compute_capability not in ((8, 0), (9, 0))))):
            self.logger.debug(
                "Disabling FlashAttention due to unsupported head_dim. "
                "Supported: %%8 == 0, and <= 256; sm80/90 for >192. "
                "Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s",
                query_layer.shape[-1], key_layer.shape[-1],
                '.'.join([str(i) for i in self.device_compute_capability]))
3910
3911
            use_flash_attention = False

3912
        # Filter: cross attention + causal mask.
3913
        # (in training mode)
3914
3915
        if (use_flash_attention
            and inference_params is None
3916
            and _flash_attn_2_1_plus
3917
            and "causal" in attn_mask_type
3918
3919
            and max_seqlen_q != max_seqlen_kv
        ):
3920
            self.logger.warning(
3921
3922
                "In training mode, disable the use of FlashAttention since version 2.1+ has "
                "changed its behavior for causal mask in cross attention. See "
3923
3924
3925
3926
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False

3927
3928
3929
        context_parallel = (self.cp_group is not None and \
            get_distributed_world_size(self.cp_group) != 1)

3930
3931
3932
        # Filter: sliding window attention.
        # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
        if window_size not in ((-1, -1), (-1, 0)):
3933
3934
            if use_fused_attention:
                self.logger.debug("Disabling FusedAttention for SWA")
3935
3936
            use_fused_attention = False
            if (not _flash_attn_2_3_plus) or context_parallel:
3937
3938
3939
3940
                if use_flash_attention:
                    self.logger.debug(
                        "Disabling FusedAttention as it requires flash-attn 2.3+ "
                        "and no context parallelism")
3941
3942
                use_flash_attention = False

3943
        # Filter: Attention mask type.
3944
        #   attn_mask_type(s)    |     supported backends
3945
        # ------------------------------------------------
3946
3947
        #   no_mask              |     All
        #   padding              |     UnfusedDotProductAttention, FlashAttention, FusedAttention
3948
        #   causal               |     All
3949
        #   padding + causal     |     FlashAttention, FusedAttention
3950
3951
3952
        #   arbitrary            |     UnfusedDotProductAttention
        #
        if attn_mask_type == "arbitrary":
3953
3954
            if use_flash_attention:
                self.logger.debug("Disabling FlashAttention for arbitrary mask")
3955
            use_flash_attention = False
3956
3957
            if use_fused_attention:
                self.logger.debug("Disabling FusedAttention for arbitrary mask")
3958
            use_fused_attention = False
3959

3960
3961
        if (use_unfused_attention
            and inference_params is None
3962
3963
3964
            and "causal" in attn_mask_type
            and max_seqlen_q != max_seqlen_kv
        ):
3965
            self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
3966
            use_unfused_attention = False
3967

3968
3969
3970
3971
3972
3973
3974
3975
3976
3977
3978
3979
3980
3981
3982
3983
3984
3985
        # Filter: bias.
        global _alibi_cache
        if alibi_slopes is not None:
            assert (core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
            if self.layer_number == 1:
                _alibi_cache["_alibi_slopes_require_update"] = True
                _alibi_cache["_alibi_bias_require_update"] = True
        if core_attention_bias_type == "alibi":
            assert (core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
            if (_alibi_cache["_num_heads"] != query_layer.shape[-2]
                or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
                or _alibi_cache["_alibi_slopes"] is None):
                _alibi_cache["_alibi_slopes_require_update"] = True
                _alibi_cache["_alibi_bias_require_update"] = True

3986
3987
3988
3989
        if (use_flash_attention
            and (core_attention_bias_type not in ["no_bias", "alibi"]
                or core_attention_bias is not None)):
            self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
3990
3991
3992
3993
3994
3995
3996
3997
3998
            use_flash_attention = False

        fu_core_attention_bias_type = core_attention_bias_type
        fu_core_attention_bias = core_attention_bias
        if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None:
            fu_core_attention_bias_type = "post_scale_bias"
            _, fu_core_attention_bias = get_alibi(
                query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes,
                bias_dtype=query_layer.dtype)
3999
4000
4001
4002
4003
4004
4005
        if (use_fused_attention
            and fu_core_attention_bias_type == "post_scale_bias"
            and (fu_core_attention_bias.shape[0] != 1
            or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
            if fu_core_attention_bias.requires_grad:
                # remove this line when cuDNN adds bwd support for
                # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
4006
                self.logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
4007
                use_fused_attention = False
4008
            else:
4009
4010
4011
                # max512 backend will only support [1, h, s, s]
                os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

4012
4013
        if use_fused_attention:
            fused_attention_backend = tex.get_fused_attn_backend(
4014
4015
4016
4017
                TE_DType[query_layer.dtype]
                if not isinstance(query_layer, Float8Tensor) else query_layer._fp8_dtype,
                TE_DType[key_layer.dtype]
                if not isinstance(key_layer, Float8Tensor) else key_layer._fp8_dtype,
4018
                QKVLayout[qkv_layout],
4019
                AttnBiasType[fu_core_attention_bias_type],
4020
                AttnMaskType[attn_mask_type],
4021
                self.attention_dropout,
4022
4023
4024
4025
4026
4027
                query_layer.shape[-2], # num_attn_heads
                key_layer.shape[-2], # num_gqa_groups
                max_seqlen_q,
                max_seqlen_kv,
                query_layer.shape[-1], # head_dim
            )
4028
4029
            # DPA does not support FP8; for FP8, use cpp_extensions modules directly
            is_backend_avail = (fused_attention_backend in
4030
4031
4032
                [FusedAttnBackend["F16_max512_seqlen"],
                FusedAttnBackend["F16_arbitrary_seqlen"],
                FusedAttnBackend["FP8"]])
4033
4034
4035
4036
            use_fused_attention = ( \
                use_fused_attention and is_backend_avail and \
                (not context_parallel or \
                 fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
4037
4038
4039
4040
            if (fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
                and fu_core_attention_bias_type == "post_scale_bias"
                and (fu_core_attention_bias.shape[0] != 1
                or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
4041
4042
                self.logger.debug(
                    "Disabling FusedAttention as no backend supports the provided input")
4043
                use_fused_attention = False
4044

4045
4046
4047
4048
4049
4050
4051
4052
4053
4054
4055
4056
4057
4058
4059
4060
        # Filter: determinism.
        # backend                                  | deterministic
        # ---------------------------------------------------------
        # flash-attn v1                            | yes
        # flash-attn v2                            | no
        # FusedAttnBackend["F16_max512_seqlen"]    | yes
        # FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
        # UnfusedDotProductAttention               | yes
        #
        # Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
        # on sm90 architectures.
        #
        if (use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and self.deterministic
            and self.device_compute_capability != (9, 0)):
4061
            self.logger.debug("Disabling FusedAttention for determinism reasons")
4062
4063
            use_fused_attention = False

4064
4065
4066
4067
4068
        # Select FusedAttention on sm90 and FlashAttention on others for performance
        if (use_flash_attention
            and use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
            if self.device_compute_capability == (9, 0):
4069
4070
4071
                self.logger.debug(
                    "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                    "for performance reasons")
4072
                use_flash_attention = False
4073

4074
4075
4076
4077
4078
4079
4080
4081
4082
4083
4084
4085
4086
4087
4088
4089
4090
4091
4092
4093
4094
        run_config = {
            "compute_capability":"sm"+str((lambda x,y: x*10+y)(
                self.device_compute_capability[0],self.device_compute_capability[1])),
            "q_dtype":query_layer.dtype,
            "k_dtype":key_layer.dtype,
            "v_dtype":value_layer.dtype,
            "q_shape":list(query_layer.shape),
            "k_shape":list(key_layer.shape),
            "v_shape":list(value_layer.shape),
            "qkv_format":qkv_format,
            "qkv_layout":qkv_layout,
            "mask_type":attn_mask_type,
            "bias_type":core_attention_bias_type,
            "bias_shape":core_attention_bias.shape if core_attention_bias is not None else None,
            "dropout":self.attention_dropout,
            "context_parallel":context_parallel,
            "is_training":self.training,
            "transformer_engine_version":te.__version__,
            "flash_attn_version":_flash_attn_version,
            "cudnn_version":'.'.join([str(i) for i in get_cudnn_version()])}

4095
        if use_flash_attention:
4096
4097
            self.logger.info("Running with FlashAttention backend ")
            self.logger.debug("Running with config=%s",run_config)
4098
4099
4100
            if core_attention_bias_type == "alibi":
                alibi_slopes, _ = get_alibi(
                    query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes)
4101
4102
4103
4104
4105
4106
4107
4108
            return self.flash_attention(query_layer,
                                        key_layer,
                                        value_layer,
                                        attention_mask=attention_mask,
                                        qkv_layout=qkv_layout,
                                        cu_seqlens_q=cu_seqlens_q,
                                        cu_seqlens_kv=cu_seqlens_kv,
                                        attn_mask_type=attn_mask_type,
4109
                                        window_size=window_size,
4110
                                        alibi_slopes=alibi_slopes,
4111
4112
                                        cp_group=self.cp_group,
                                        cp_global_ranks=self.cp_global_ranks,
4113
4114
4115
                                        cp_stream=self.cp_stream,
                                        max_seqlen_q=max_seqlen_q,
                                        max_seqlen_kv=max_seqlen_kv)
4116

4117
        if use_fused_attention:
4118
4119
4120
4121
            self.logger.info(
                "Running with FusedAttention backend (sub-backend %s)",
                int(fused_attention_backend))
            self.logger.debug("Running with config=%s",run_config)
4122
            if checkpoint_core_attention:
4123
4124
4125
4126
4127
4128
4129
4130
                return self._checkpointed_attention_forward(
                    self.fused_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
4131
4132
4133
4134
                    seq_offsets_q=seq_offsets_q,
                    seq_offsets_k=seq_offsets_k,
                    seq_offsets_v=seq_offsets_v,
                    seq_offsets_o=seq_offsets_o,
4135
4136
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
4137
4138
4139
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    fused_attention_backend=fused_attention_backend,
4140
4141
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
4142
4143
4144
4145
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
4146
                    is_first_microbatch=is_first_microbatch)
4147
4148
4149
4150
4151
4152
4153
            return self.fused_attention(
                query_layer,
                key_layer,
                value_layer,
                qkv_layout=qkv_layout,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
4154
4155
4156
4157
                seq_offsets_q=seq_offsets_q,
                seq_offsets_k=seq_offsets_k,
                seq_offsets_v=seq_offsets_v,
                seq_offsets_o=seq_offsets_o,
4158
4159
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
4160
4161
4162
                attn_mask_type=attn_mask_type,
                attention_mask=attention_mask,
                fused_attention_backend=fused_attention_backend,
4163
4164
                core_attention_bias_type=fu_core_attention_bias_type,
                core_attention_bias=fu_core_attention_bias,
4165
4166
4167
4168
                fast_zero_fill=fast_zero_fill,
                cp_group=self.cp_group,
                cp_global_ranks=self.cp_global_ranks,
                cp_stream=self.cp_stream,
4169
                is_first_microbatch=is_first_microbatch)
4170
4171
4172

        assert (not context_parallel), \
            "Context parallelism is only implemented with Flash Attention and Fused Attention!"
4173

4174
4175
4176
4177
4178
4179
4180
        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            warnings.warn(
                           "Attention activation Offloading is only implemented"
                           "with Flash Attention and Fused Attention!"
                         )

4181
        if use_unfused_attention:
4182
4183
            self.logger.info("Running with UnfusedDotProductAttention backend")
            self.logger.debug("Running with config=%s",run_config)
4184
4185
4186
4187
4188
4189
4190
4191
4192
4193
4194
4195
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.unfused_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout = qkv_layout,
                    cu_seqlens_q = cu_seqlens_q,
                    cu_seqlens_kv = cu_seqlens_kv,
                    attn_mask_type = attn_mask_type,
                    attention_mask = attention_mask,
                    core_attention_bias_type = core_attention_bias_type,
4196
4197
                    core_attention_bias = core_attention_bias,
                    alibi_slopes = alibi_slopes)
4198
4199
4200
4201
4202
4203
4204
4205
4206
            return self.unfused_attention(query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout = qkv_layout,
                    cu_seqlens_q = cu_seqlens_q,
                    cu_seqlens_kv = cu_seqlens_kv,
                    attn_mask_type = attn_mask_type,
                    attention_mask = attention_mask,
                    core_attention_bias_type = core_attention_bias_type,
4207
4208
                    core_attention_bias = core_attention_bias,
                    alibi_slopes = alibi_slopes)
4209
4210

        raise Exception("No dot product attention support for the provided inputs!")
4211
4212


4213
4214
4215
4216
4217
4218
4219
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

4220
4221
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                   default = `causal`
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
4255
4256
4257
4258
4259
4260
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
4261
4262
4263
4264
4265
4266
4267
4268
4269
4270
4271
4272
4273
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
4274
4275
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
4276
4277
4278
4279
4280
4281
4282
4283
4284
4285
4286
4287
4288
4289
4290
4291
4292
4293
4294
4295
4296
4297
4298
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
4299
4300
4301
4302
4303
4304
4305
4306
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
            For that, please use `_get_qkv_layout` to gain the layout information.
4307
4308
4309
4310
4311
4312
4313
4314
4315
4316
4317
4318
4319
4320
4321
4322
4323
4324
4325
4326
4327
4328
4329
4330
4331
4332
4333
4334
4335
4336
4337
4338
4339
4340
4341
4342
4343
4344
4345
4346

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 are used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
4347
4348
4349
4350
4351
4352
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """Build the QKV projection(s), the core attention module and the output projection.

        All constructor arguments are documented on the class. Construction runs in
        three stages:

        1. store the configuration on ``self`` and validate it,
        2. derive the per-partition head counts and the Q / KV projection widths,
        3. instantiate the projection layers, ``core_attention`` and ``proj``.

        Raises
        ------
        AssertionError
            if ``attention_type`` is not one of ``AttnTypes``, ``layer_number`` is not
            positive, ``num_attention_heads`` is not divisible by ``num_gqa_groups``, or
            ``num_gqa_groups`` is not divisible by the tensor-parallel size.
        """
        super().__init__()

        # ---- Stage 1: configuration ------------------------------------------
        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        # Normalize the window size against the init-time mask type in one step.
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        # A falsy kv_channels (None or 0) falls back to an even split of hidden_size.
        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaved QKV weights only make sense for a single fused QKV parameter;
        # with separate parameters the layout is always treated as non-interleaved.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # ---- Stage 2: parallel partitioning and projection widths -------------
        # An explicit process group overrides the tp_size argument.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        # Sequence parallelism is only meaningful when there is more than one TP rank.
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
        # Without an explicit GQA group count, every head gets its own K/V (plain MHA).
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        assert (num_attention_heads % self.num_gqa_groups == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (self.num_gqa_groups % tp_size == 0
                ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)

        # Q is projected to one head_dim slice per attention head; K and V each get one
        # head_dim slice per GQA group.
        self.hidden_size_per_attention_head = kv_channels
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

        # ---- Stage 3: submodules ---------------------------------------------
        # Keyword arguments shared by every GEMM-based submodule built below.
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # One projection produces Q, K and V together (width q + 2 * kv). When the
            # QKV parameters are not fused, the named widths below are passed as
            # `parameters_split` so the weight is exposed as query/key/value pieces.
            parameters_split = None
            if not fuse_qkv_params:
                parameters_split = collections.OrderedDict([
                    ("query", self.hidden_size_q),
                    ("key", self.hidden_size_kv),
                    ("value", self.hidden_size_kv),
                ])
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # Cross-attention projects Q from the decoder input and K/V (width 2 * kv)
            # from the encoder output with two separate projections.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    self.hidden_size_q,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            self.hidden_size_per_attention_head,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Output projection back to hidden_size; row-parallel to pair with the
        # column-parallel QKV projection when set_parallel_mode is enabled.
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )


    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return an uninitialized key/value cache buffer for inference.

        The buffer has shape
        ``[inference_max_sequence_len, batch_size, num_gqa_groups_per_partition,
        hidden_size_per_attention_head]`` with the requested ``dtype``. It is
        created with ``torch.empty`` on the current CUDA device, so its contents
        are undefined until written.
        """
        cache_shape = (
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        target_device = torch.cuda.current_device()
        return torch.empty(cache_shape, dtype=dtype, device=target_device)

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Attach a tensor parallel process group to this module.

        Intended to be called before the forward pass, e.g. when the module was
        constructed with only `tp_size` and no process group.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        # Only this module's own attribute is replaced; submodules are left untouched.
        self.tp_group = tp_group

4582
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Propagate context parallel settings to every descendant module that
        supports them; call before executing the forward pass.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        for submodule in self.modules():
            # `modules()` yields this module itself first; skipping it prevents
            # this method from calling itself forever.
            if submodule is self:
                continue
            if not hasattr(submodule, "set_context_parallel_group"):
                continue
            submodule.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
4607

4608
4609
4610
    def forward(
        self,
        hidden_states: torch.Tensor,
4611
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
4612
        encoder_output: Optional[torch.Tensor] = None,
4613
        attn_mask_type: Optional[str] = None,
4614
        window_size: Optional[Tuple[int, int]] = None,
4615
4616
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
4617
        inference_params: Optional[InferenceParams] = None,
4618
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
4619
4620
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
4621
        alibi_slopes: Optional[torch.Tensor] = None,
4622
        fast_zero_fill: bool = True,
4623
    ) -> Tuple[Union[torch.Tensor, None], ...]:
4624
4625
4626
4627
4628
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

4629
4630
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
4631
4632
4633
4634
4635

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
4636
4637
4638
4639
4640
4641
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
4642
4643
4644
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value
             means the corresponding position is masked out and a `False` means that position is
             allowed to participate in attention.
4645
4646
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                       default = `None`
4647
                       type of attention mask passed into softmax operation.
4648
4649
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
4650
4651
4652
4653
4654
4655
4656
4657
4658
4659
4660
4661
4662
4663
4664
4665
4666
4667
4668
4669
4670
4671
4672
4673
4674
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
4675
                    Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`, `alibi`}
4676
        core_attention_bias: Optional[torch.Tensor], default = `None`
4677
4678
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
4679
4680
4681
4682
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
4683
4684
4685
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
4686
4687
        # hidden_states: [sq, b, h]

4688
4689
        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)
4690
        if attn_mask_type is None:
4691
            attn_mask_type = self.attn_mask_type
4692
4693
        if window_size is None:
            window_size = self.window_size
4694

4695
4696
4697
4698
4699
        if "padding" in attn_mask_type and attention_mask is not None:
            for i,_ in enumerate(attention_mask):
                assert (
                    attention_mask[i].dtype == torch.bool
                ), "Attention mask must be in boolean type!"
4700

4701
4702
        assert (core_attention_bias_type in AttnBiasTypes
                ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
4703

4704
        # =================================================
4705
        # Pre-allocate memory for key-values for inference
4706
4707
4708
4709
        # =================================================

        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
4710
                inf_max_seq_len = inference_params.max_sequence_length
4711
4712
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
4713
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
4714
4715
                )
                inference_value_memory = self._allocate_memory(
4716
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
4717
4718
4719
4720
4721
4722
4723
4724
4725
4726
4727
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

4728
        # ======================
4729
        # Query, Key, and Value
4730
        # ======================
4731

cyanguwa's avatar
cyanguwa committed
4732
4733
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
4734
4735
4736
4737
4738
4739
4740
4741
4742
4743
4744
4745
4746
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
4747
                    is_first_module_in_mha=True, # specific to FP8 MHA
4748
4749
                )

cyanguwa's avatar
cyanguwa committed
4750
4751
            num_queries_per_key_value = (self.num_attention_heads_per_partition //
                                         self.num_gqa_groups_per_partition)
4752
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
4753
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
4754
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
4755
4756
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
4757
4758
4759
4760
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
4761
4762
4763
4764
4765
4766
4767
4768
4769
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head
                )
                # split along third last dimension
                split_dim = -3
4770
4771
4772

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
4773
4774
4775
4776
4777
4778
4779
4780
4781
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
4782
                )
4783
            else:
cyanguwa's avatar
cyanguwa committed
4784
4785
4786
4787
4788
4789
4790
4791
4792
4793
4794
4795
                query_layer, key_layer, value_layer = torch.split(
                    mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim,
                 )

            # query: -> [sq, b, np, hn]
            # key, value: -> [sq, b, ng, hn]
            query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1,
                                                             self.hidden_size_per_attention_head)
                                                   for x in (query_layer, key_layer, value_layer))

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
4796
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
4797
                encoder_output,
4798
                is_first_microbatch=is_first_microbatch,
4799
                is_first_module_in_mha=True, # specific to FP8 MHA
4800
4801
4802
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
4803
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
4804
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
4805
                    self.num_gqa_groups_per_partition,
4806
4807
4808
4809
4810
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
4811
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
4812
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
4813
                    2 * self.num_gqa_groups_per_partition,
4814
4815
4816
4817
4818
4819
4820
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
4821
4822
4823
4824
4825
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2,
                )
4826
            else:
cyanguwa's avatar
cyanguwa committed
4827
4828
4829
                key_layer, value_layer = torch.split(
                    mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
                )
4830
4831
4832
            key_layer, value_layer = (x.reshape(
                x.size(0), x.size(1), -1, self.hidden_size_per_attention_head,
                ) for x in (key_layer, value_layer))
4833
4834
4835
4836
4837
4838
4839
4840
4841
4842
4843
4844
4845
4846
4847

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
4848
                    is_first_module_in_mha=True, # specific to FP8 MHA
4849
4850
4851
4852
4853
4854
4855
4856
4857
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

4858
4859
4860
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
4861

4862
        if rotary_pos_emb is not None:
4863
4864
4865
            assert (not isinstance(query_layer, Float8Tensor)
                and not isinstance(key_layer, Float8Tensor)
                ), "RoPE is not supported for Float8Tensors!"
4866
            # duplicate the pos_emb for self attention
4867
4868
4869
4870
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

            q_pos_emb, k_pos_emb = rotary_pos_emb
4871
4872
4873
4874
4875
4876
4877
4878
4879
4880
4881
4882
4883
4884

            # adjust key and value for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

4885
4886
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
4887

4888
4889
4890
4891
        # ===========================
        # Core attention computation
        # ===========================

4892
4893
4894
4895
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
4896
            qkv_format=self.qkv_format,
4897
4898
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
4899
4900
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
4901
            window_size=window_size,
4902
4903
4904
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
4905
            alibi_slopes=alibi_slopes,
4906
            fast_zero_fill=fast_zero_fill,
4907
            inference_params=inference_params,
4908
4909
        )

4910
        # ===================
4911
        # Output. [sq, b, h]
4912
        # ===================
4913

4914
        projection_output = self.proj(
4915
4916
            context_layer,
            is_first_microbatch=is_first_microbatch,
4917
4918
        )

4919
4920
4921
4922
4923
4924
4925
4926
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
4927
        if self.input_layernorm and self.return_layernorm_output:
4928
4929
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]