attention.py 223 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
8
9
from contextlib import nullcontext
from importlib.metadata import version
import math
10
import os
11
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
import warnings
13

cyanguwa's avatar
cyanguwa committed
14
import numpy as np
15
from pkg_resources import packaging
16
17

import torch
18
import torch.nn.functional as F
19
20

import transformer_engine_extensions as tex
21
22
23
24
from transformer_engine.pytorch.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
)
25
26
27
28
29
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
30
31
    fused_attn_fwd,
    fused_attn_bwd,
32
33
34
35
36
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
37
38
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
39
from transformer_engine.pytorch.module import LayerNormLinear, Linear
40
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
41
42
43
44
45
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
46
    get_default_init_method,
47
48
49
50
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
51
    AttnBiasTypes,
52
    QKVLayouts,
53
    dist_group_type,
54
    TE_DType,
55
56
57
58
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
59
    get_distributed_rank,
60
    checkpoint,
61
62
63
    set_all_rng_states,
    CudaRNGStatesTracker,
    graph_safe_rng_available,
64
65
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
66
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
67
68
from transformer_engine.pytorch.graph import is_graph_capturing

69
70

_flash_attn_version = packaging.version.Version(version("flash-attn"))
71
_flash_attn_version_required = packaging.version.Version("2.0.6")
72
_flash_attn_max_version = packaging.version.Version("2.5.8")
73
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
74
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
75
76
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1")
77

78
if _flash_attn_version >= _flash_attn_version_required:
79
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
80
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
81
82
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
83

84
85
86
87
88
89
# Short aliases for members of the extension's FP8 forward/backward tensor enums.
# NOTE(review): the names suggest these select the FP8 meta slots used for QKV, dQKV,
# O, dO, S and dP in FP8 attention -- confirm against the code that consumes them.
META_QKV  = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O    = tex.FP8FwdTensors.GEMM2_INPUT
META_DO   = tex.FP8BwdTensors.GRAD_INPUT2
META_S    = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP   = tex.FP8BwdTensors.GRAD_INPUT3
90

91
# Integer debug flag, read once at import time from the NVTE_DEBUG environment
# variable; defaults to 0 when the variable is unset.
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
92
93
94
95
96
97
98
99
100
# Module-level cache read and written by `get_alibi`:
#   "_num_heads"                   -- head count recorded when the slopes were last refreshed
#   "_alibi_slopes"                -- cached slopes tensor (None until first refresh)
#   "_max_seqlen_q"/"_max_seqlen_kv" -- sequence lengths recorded when the bias was last built
#   "_alibi_bias"                  -- cached bias tensor (None until first build)
#   "_alibi_slopes_require_update" -- when True, `get_alibi` refreshes the slopes and clears it
#   "_alibi_bias_require_update"   -- when True, `get_alibi` rebuilds the bias and clears it
# `get_alibi` only ever clears the two flags; they must be set to True elsewhere.
_alibi_cache = {
    "_num_heads": None,
    "_alibi_slopes": None,
    "_max_seqlen_q": None,
    "_max_seqlen_kv": None,
    "_alibi_bias": None,
    "_alibi_slopes_require_update": False,
    "_alibi_bias_require_update": False,
    }
101
102


103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Names exported by `from <this module> import *`.
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]

class InferenceParams: # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Both offsets start at zero; this class itself never advances them.
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Maps a layer number to its (key memory, value memory) pair.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Every layer is validated before any layer is modified, so a size
        mismatch never leaves the cache partially reordered.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        AssertionError
            If `len(batch_indices)` differs from dimension 1 of any cached key tensor.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict is empty")

        # First pass: make sure the batch size matches for every layer.
        for layer_number, (key_memory, _) in self.key_value_memory_dict.items():
            assert len(batch_indices) == key_memory.shape[1], (
                f"batch_indices has length {len(batch_indices)}, but the KV cache of layer "
                f"{layer_number} has batch size {key_memory.shape[1]}"
            )

        # Second pass: reorder along dim 1. Only existing keys are reassigned,
        # so the dict does not change size while it is being iterated.
        for layer_number, (key_memory, value_memory) in self.key_value_memory_dict.items():
            self.key_value_memory_dict[layer_number] = (
                key_memory[:, batch_indices],
                value_memory[:, batch_indices],
            )
149

150
151
152
153
154
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the ALiBi slopes and bias held in the module-level `_alibi_cache`,
    regenerating whichever of the two is flagged as stale in that cache.

    The slopes are refreshed only when `_alibi_cache["_alibi_slopes_require_update"]`
    is set, and the bias only when `_alibi_cache["_alibi_bias_require_update"]` is set;
    both flags are cleared here after the refresh. When neither flag is set the
    arguments are ignored and the cached values are returned unchanged.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
        then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
        `alibi_slopes` is in [batch_size, num_heads], then the bias is in
        [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].

    Raises
    ------
    ValueError
        If the bias has to be rebuilt and the cached slopes are neither 1-D nor 2-D.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            # Custom slopes are cached as given.
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # n is the largest power of two that does not exceed num_heads.
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                # The remaining (num_heads - n) heads use odd powers of a second base.
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        slopes = _alibi_cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        # Reshape the slopes so they broadcast over the two sequence dimensions.
        if slopes.dim() == 1:
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            slopes_shape = torch.Size([*slopes.shape[:], 1, 1])
        else:
            raise ValueError(
                "ALiBi slopes must be of shape [num_heads] or [batch_size, num_heads], "
                f"but got a {slopes.dim()}-D tensor!"
            )
        # Key positions minus query positions, both counted back from the last token,
        # give the relative distance of every (query, key) pair.
        bias = torch.arange(
            1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
        bias = bias - torch.arange(
            1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(1, 1, max_seqlen_q, 1)
        # The bias is the negated absolute distance scaled by the per-head slope.
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    The length of each sample is taken as the sum of its mask entries. The result
    lives on the same device as `mask`.
    """
    # [batch_size, 1, 1, max_seqlen] -> [batch_size, max_seqlen]
    mask = mask.squeeze(1).squeeze(1)
    seqlens = mask.sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Prepend the leading zero on the tensor's own device instead of a hard-coded one,
    # so the concatenation never mixes devices.
    zero = cu_seqlens.new_zeros(1)
    return torch.cat((zero, cu_seqlens))

234

235
236
237
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and another int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the indices for the valid tokens.

    The indices are the flat positions of the nonzero mask entries; the tail of the
    tensor is filled with the out-of-range value `batch_size * max_seqlen`.
    Both results live on the same device as `mask`.
    """
    # [batch_size, 1, 1, max_seqlen] -> [batch_size, max_seqlen]
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    reduced_mask = mask.sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Allocate the leading zero on the tensor's own device rather than a hard-coded one.
    zero = cu_seqlens.new_zeros(1)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # [num_nonzeros, 1] -> [num_nonzeros, 1, 1]
    indices = mask.reshape(-1).nonzero().unsqueeze(-1)

    # Pad along dim 0 up to batch_size * max_seqlen entries.
    num_nonzeros = indices.shape[0]
    pad_amount = bs * seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * seqlen))

    return cu_seqlens, indices


262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Token `ii` of sample `i` gets the index `i * max_seqlen + ii`; the tail of the
    tensor is filled with the out-of-range value `batch_size * max_seqlen`.
    The result lives on the same device as `cu_seqlens`.
    """
    bs = len(cu_seqlens) - 1
    # One transfer to Python ints instead of touching the tensor element by element.
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    flat_indices = [
        i * max_seqlen + ii for i, seqlen in enumerate(seqlens) for ii in range(seqlen)
    ]
    # Build the tensor directly as int64: going through a float32 tensor would
    # round indices above 2**24.
    indices = torch.tensor(
        flat_indices, dtype=torch.int64, device=cu_seqlens.device
    ).unsqueeze(1).unsqueeze(1)

    # Pad along dim 0 up to batch_size * max_seqlen entries.
    num_nonzeros = indices.shape[0]
    pad_amount = bs * max_seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * max_seqlen))

    return indices

281
_cu_seqlens_cache = {}
282
283
284
285
286
287
288
289
290
291
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length.

    """
292
293
294
295
296
297
298
299
300
301
    global _cu_seqlens_cache
    if (batch_size, max_seqlen) not in _cu_seqlens_cache:
        _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[(batch_size, max_seqlen)]
302
303


304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Gathers rows of a 3D `tensor` along dim 0 according to `indices`.

    One all-zero row is appended to `tensor` before gathering, so an index equal
    to the original `tensor.shape[0]` selects zeros.
    """
    zero_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    padded = torch.cat((tensor, zero_row), dim=0)

    # Tile the indices over dims 1 and 2 so that whole rows are gathered.
    gather_index = indices.repeat(1, padded.shape[1], padded.shape[2])
    return torch.gather(padded, 0, gather_index)


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies `pack_tensor` with the same `indices` to each of the 2 tensors.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies `pack_tensor` with the same `indices` to each of the 3 tensors.
    """
    packed_1 = pack_tensor(indices, t1)
    packed_2 = pack_tensor(indices, t2)
    packed_3 = pack_tensor(indices, t3)
    return packed_1, packed_2, packed_3


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`: scatters the rows of `tensor` into a zero-initialized
    tensor with `dim0` rows.

    The scatter target has one extra row at the end, which is dropped from the
    result, so rows scattered to index `dim0` are discarded.
    """
    target = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    # Tile the indices over dims 1 and 2 so that whole rows are scattered.
    scatter_index = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    target.scatter_(0, scatter_index, tensor)
    return target[0:-1, :, :]


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`: applies `unpack_tensor` to each of the 2 tensors.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`: applies `unpack_tensor` to each of the 3 tensors.
    """
    unpacked_1 = unpack_tensor(indices, dim0, t1)
    unpacked_2 = unpack_tensor(indices, dim0, t2)
    unpacked_3 = unpack_tensor(indices, dim0, t3)
    return unpacked_1, unpacked_2, unpacked_3


class PackTensors(torch.autograd.Function):
    """
    Autograd function that packs one, two or three tensors with shared `indices`;
    the backward pass unpacks the incoming gradients with the same `indices`.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        assert len(tensors) in (1, 2, 3), f"Packing {len(tensors)} tensors not supported."
        # Remember what backward needs: the indices and the unpacked leading size.
        ctx.dim0 = tensors[0].shape[0]
        ctx.indices = indices
        num_tensors = len(tensors)
        if num_tensors == 3:
            return pack_3_tensors(indices, *tensors)
        if num_tensors == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_tensor(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        num_grads = len(grad_outputs)
        if num_grads == 1:
            grads = (unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs),)
        elif num_grads == 2:
            grads = unpack_2_tensors(ctx.indices, ctx.dim0, *grad_outputs)
        else:
            grads = unpack_3_tensors(ctx.indices, ctx.dim0, *grad_outputs)
        # No gradient flows to `indices`.
        return (None, *grads)


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function that unpacks a tensor with `unpack_tensor`; the backward
    pass packs the incoming gradient with the same `indices`.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        # Keep the indices around for the backward pass.
        ctx.indices = indices
        unpacked = unpack_tensor(indices, dim0, tensor)
        return unpacked

    @staticmethod
    def backward(ctx, grad_output):
        grad_tensor = pack_tensor(ctx.indices, grad_output)
        # No gradients for `indices` and `dim0`.
        return None, None, grad_tensor


447
448
449
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                               recv_tensor, recv_src,
                               cp_group, batch_p2p_comm):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Starts one asynchronous send of `send_tensor` to `send_dst` and one asynchronous
    receive into `recv_tensor` from `recv_src` on `cp_group`. Even ranks issue the send
    first and odd ranks issue the receive first. With `batch_p2p_comm` the two operations
    are submitted together through `torch.distributed.batch_isend_irecv`; otherwise they
    are issued one by one.

    Returns the list of outstanding requests, in the order they were issued.
    """
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        send_op = torch.distributed.P2POp(torch.distributed.isend,
                                          send_tensor,
                                          send_dst,
                                          cp_group)
        recv_op = torch.distributed.P2POp(torch.distributed.irecv,
                                          recv_tensor,
                                          recv_src,
                                          cp_group)
        ordered_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return torch.distributed.batch_isend_irecv(ordered_ops)

    # Non-batched path: the calls themselves must happen in the rank-dependent order.
    if send_first:
        send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
        recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]

    recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
    send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]


493
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim,
                                  softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism

    Adds `out_per_step`, rescaled by `exp(softmax_lse_per_step - softmax_lse)`,
    to `out` in place. Nothing is returned.
    """
    # Dim 2 of the LSE difference is moved to `seq_dim`, and a trailing dim is
    # added so the factor broadcasts against the per-step output.
    rescale = torch.exp(softmax_lse_per_step - softmax_lse)
    rescale = rescale.movedim(2, seq_dim).unsqueeze(-1)
    out.add_(out_per_step * rescale)


503
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism

    Overwrites `softmax_lse` in place with
    `log(exp(softmax_lse) + exp(softmax_lse_per_step))`, evaluated relative to the
    element-wise larger of the two inputs. Nothing is returned.
    """
    larger = torch.max(softmax_lse, softmax_lse_per_step)
    smaller = torch.min(softmax_lse, softmax_lse_per_step)
    # smaller - larger <= 0, so the exponential cannot overflow.
    merged = larger + torch.log(1 + torch.exp(smaller - larger))
    softmax_lse.copy_(merged)
510
511


512
class AttnFuncWithCP(torch.autograd.Function):
513
    """
514
515
    Attention implementation with context parallelism.
    Split attention compute into multiple steps, and overlap current-step
516
517
518
519
    compute with next-step communication.
    """

    @staticmethod
520
    def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
521
522
                dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
                attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention):
523
524
525
526
527
528
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
529
        recv_src = cp_global_ranks[(rank - 1) % cp_size]
530
531
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

532
533
        causal = (attn_mask_type == "causal")

534
535
        qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

536
        if causal:
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
            if qkv_format == "bshd":
                # [b, s, np, hn] -> [b, 2, s//2, np, hn]
                q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
            elif qkv_format == "sbhd":
                # [s, b, np, hn] -> [2, s//2, b, np, hn]
                q, k, v = [x.view(2, x.shape[0]//2, *x.shape[1:]) for x in [q, k, v]]
        if attn_bias is not None:
            assert (len(attn_bias.shape) == 4), (
                "Only support bias shape of [b, h, sq, sk] for forward, "
                "and [1, h, sq, sk] for backward!"
            )
            # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_bias_ = attn_bias.view( \
                *attn_bias.shape[:-2], \
                2, attn_bias.shape[-2]//2, \
                2*cp_size, attn_bias.shape[-1]//(2*cp_size) \
            )
            # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)]
            attn_bias = attn_bias.view( \
                *attn_bias.shape[:-1], \
                2*cp_size, attn_bias.shape[-1]//(2*cp_size) \
            )
559
        assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
560
561
562
563
564
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
565

566
567
568
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
569
        attn_bias_inputs = [None, None]
570
571
572
573
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]
574
        attn_biases = [None for _ in range(cp_size)]
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
        p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
        send_recv_reqs = [[], []]

        for i in range(cp_size+1):
            if i < cp_size:
                with torch.cuda.stream(flash_attn_streams[i%2]):
                    # wait until KV is received
                    for req in send_recv_reqs[(i+1)%2]:
                        req.wait()

                    if i < (cp_size-1):
                        p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
                                                                         p2p_comm_buffers[i],
                                                                         send_dst,
                                                                         p2p_comm_buffers[i+1],
                                                                         recv_src,
                                                                         cp_group,
                                                                         batch_p2p_comm)

                    kv_inputs[i%2] = p2p_comm_buffers[i]
                    if causal:
                        if i == 0:
605
                            if use_fused_attention:
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                    q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2].view(
                                        2, k.shape[0], -1, *k.shape[-2:])
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                    q_inputs[i%2] = q.view(-1, *q.shape[-3:])
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2].view(
                                        2, -1, *k.shape[-3:])
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
                                    attn_bias_inputs[i%2] = torch.cat(
                                        (attn_bias[..., idx, :], \
                                         attn_bias[..., (2*cp_size-idx-1), :]),
                                        dim=-1
                                    ).contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
626
627
628
629
630
631
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
632
633
                                    qkv_layout=qkv_layout, attn_mask_type="causal",
                                    attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
634
                                )
635
636
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
637
638
639
640
641
642
643
644
645
646
647
648
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                                q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=True, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
649
                        elif i <= rank:
650
                            if use_fused_attention:
651
652
653
654
655
656
657
658
659
660
661
662
663
664
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                    q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                                    q_inputs[i%2] = q.view(-1, *q.shape[-3:])
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2][:, 0, ...].contiguous()
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
                                    attn_bias_inputs[i%2] = attn_bias[..., idx, :].contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
665
666
667
668
669
670
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q,
                                    cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
671
672
                                    qkv_layout=qkv_layout, attn_mask_type="no_mask",
                                    attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
673
                                )
674
675
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
676
677
678
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                                q_inputs[i%2] = q.view(-1, *q.shape[-2:])
679
680
681
682
683
684
685
                                if qkv_format == "thd":
                                    # [2, t, np, hn] -> [2, t/2, np, hn]
                                    kv_inputs[i%2] = tex.thd_read_half_tensor(
                                        kv_inputs[i%2], cu_seqlens_k, 0)
                                else:
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
686
687
688
689
690
691
692
693
694
695
696
697
698
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                        else:
                            if use_fused_attention:
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
                                if qkv_format == "bshd":
                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                    q_inputs[i%2] = q[:, 1, ...].contiguous()
                                    # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2].view(
                                        2, k.shape[0], -1, *k.shape[-2:])
                                elif qkv_format == "sbhd":
                                    # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                                    q_inputs[i%2] = q[1].contiguous()
                                    # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                                    kv_inputs[i%2] = kv_inputs[i%2].view(
                                        2, -1, *k.shape[-3:])
                                if attn_bias is not None:
                                    idx = (rank - i) % cp_size
                                    attn_bias_inputs[i%2] = torch.cat(
                                        (attn_bias_[..., 1, :, idx, :], \
                                         attn_bias_[..., 1, :, (2*cp_size-idx-1), :]),
                                        dim=-1
                                    ).contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
719
720
721
722
723
724
                                fused_attn_fwd(
                                    is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
725
726
                                    qkv_layout=qkv_layout, attn_mask_type="no_mask",
                                    attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
727
                                )
728
729
                                if len(rest) > 0:
                                    attn_biases[i] = rest[0]
730
                            else:
731
732
733
734
735
736
737
                                if qkv_format == "thd":
                                    # [t, np, hn] -> [t/2, np, hn]
                                    q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                                else:
                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
                                    q_inputs[i%2] = \
                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
738
739
740
741
742
743
744
745
746
747
748
749
750
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                    else:
                        if use_fused_attention:
751
752
753
754
755
756
757
                            if attn_bias is not None:
                                idx = (rank - i) % cp_size
                                attn_bias_inputs[i%2] = torch.cat(
                                    (attn_bias[..., idx, :], attn_bias[..., (2*cp_size-idx-1), :]),
                                    dim=-1
                                ).contiguous()
                            out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \
758
759
760
761
762
763
                            fused_attn_fwd(
                                is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                cu_seqlens_k, q, kv_inputs[i%2][0],
                                kv_inputs[i%2][1], TE_DType[q.dtype],
                                tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                attn_scale=softmax_scale, dropout=dropout_p,
764
765
                                qkv_layout=qkv_layout, attn_mask_type="no_mask",
                                attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
766
                            )
767
768
                            if len(rest) > 0:
                                attn_biases[i] = rest[0]
769
                        else:
770
771
772
                            # [b, sq, np, hn] -> [b*sq, np, hn]
                            q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
773
                            kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
774
775
776
                            _, _, _, _, out_per_step[i], \
                            softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
777
778
779
                                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal=False, return_softmax=False,
                                **fa_optional_forward_kwargs
780
                            )
781
782
783
784
785
786

            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
                    flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)

787
788
789
790
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
                    softmax_lse_per_step[i-1].squeeze_(-1)

791
                with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
792
793
794
                    if i == 1:
                        out = torch.empty_like(q).zero_()
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
795
                        if causal and qkv_format != "thd":
796
797
798
799
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
                            )
800
801
802
                    elif (i-1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(softmax_lse,
                                                              softmax_lse_per_step[i-1])
803
                    else:
804
805
806
807
808
809
810
811
                        if qkv_format == "thd":
                            tex.thd_second_half_lse_correction(softmax_lse,
                                                               softmax_lse_per_step[i-1],
                                                               cu_seqlens_q,
                                                               q.size(0))
                        else:
                            flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
                                                                  softmax_lse_per_step[i-1])
812
813
814
815
816
817
818

                if i < cp_size:
                    flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
819
820
        if qkv_format in ["bshd", "sbhd"]:
            seq_dim = qkv_format.index("s")
821
        for i in range(cp_size):
822
823
824
825
826
827
            if qkv_format == "bshd":
                out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
                out_ = out[:, 1, ...]
            elif qkv_format == "sbhd":
                out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
                out_ = out[1]
828

829
            if i <= rank or not causal:
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
                if qkv_format in ["bshd", "sbhd"]:
                    flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape),
                                                  out_per_step[i],
                                                  seq_dim,
                                                  softmax_lse,
                                                  softmax_lse_per_step[i])
                elif qkv_format == "thd":
                    tex.thd_out_correction(out,
                                           out_per_step[i],
                                           softmax_lse,
                                           softmax_lse_per_step[i],
                                           cu_seqlens_q,
                                           False)
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
845
            else:
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
                if qkv_format in ["bshd", "sbhd"]:
                    flash_attn_fwd_out_correction(out_,
                                                  out_per_step[i],
                                                  seq_dim,
                                                  softmax_lse_[..., 1, :],
                                                  softmax_lse_per_step[i])
                elif qkv_format == "thd":
                    tex.thd_out_correction(out,
                                           out_per_step[i],
                                           softmax_lse,
                                           softmax_lse_per_step[i],
                                           cu_seqlens_q,
                                           True)
                else:
                    assert False, f"{qkv_format} is an unsupported qkv_format!"
861
862

        kv = p2p_comm_buffers[-1]
863
        if use_fused_attention:
864
865
866
867
            if qkv_format == "bshd":
                out = out.view(out.shape[0], -1, *out.shape[-2:])
            elif qkv_format == "sbhd":
                out = out.view(-1, *out.shape[-3:])
868
869
        else:
            out = out.view(-1, *out.shape[-2:])
870

871
872
873
874
875
876
877
878
879
        ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
        ctx.rng_states = rng_states
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
880
881
882
883
        ctx.qkv_format = qkv_format
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
        ctx.attn_biases = attn_biases
884
        ctx.deterministic = deterministic
885
        ctx.use_fused_attention = use_fused_attention
886
887
888
889
890
891
892
893
        return out

    @staticmethod
    def backward(ctx, dout):
        q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
894
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
895
896
897
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
        qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format

        if ctx.attn_biases[0] is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)]
            attn_dbias = torch.zeros(
                *ctx.attn_bias_shape,
                dtype=ctx.attn_biases[0].dtype,
                device=ctx.attn_biases[0].device
            )
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)]
            attn_dbias_ = attn_dbias.view(
                *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3]//2, *attn_dbias.shape[-2:]
            )
        else:
            attn_dbias = None

914
        if ctx.causal:
915
916
917
918
919
920
921
922
923
924
925
            if ctx.qkv_format == "thd":
                softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
            else:
                # [b, np, sq] -> [b, np, 2, sq//2]
                softmax_lse_ = \
                    softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
                softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
                if ctx.use_fused_attention:
                    # [b, np, sq//2] -> [b, np, sq//2, 1]
                    softmax_lse_.unsqueeze_(-1)

926
927
928
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse.unsqueeze_(-1)
929
930
931
932
933
934
935
936
937
938
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        # Flash Attn outputs
        dq = torch.empty_like(q)

        p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
                            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
        p2p_comm_buffers[0][0].copy_(kv)
        send_recv_reqs = []

939
940
941
942
943
944
        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

            send_tensor = p2p_comm_buffers[i%2]
            recv_tensor = p2p_comm_buffers[(i+1)%2]
            if i == 0:
                send_tensor = send_tensor[0]
                recv_tensor = recv_tensor[0]
            if i == (cp_size-1):
                send_tensor = send_tensor[1]
                recv_tensor = recv_tensor[1]

            send_recv_reqs = flash_attn_p2p_communicate(rank,
                                                        send_tensor,
                                                        send_dst,
                                                        recv_tensor,
                                                        recv_src,
                                                        ctx.cp_group,
                                                        batch_p2p_comm)

            kv = p2p_comm_buffers[i%2][0]
            # In reversed order of fwd
            if ctx.causal:
                if i == (cp_size-1):
971
                    if ctx.use_fused_attention:
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                            kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
                            # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
                        aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
                        if attn_dbias is not None:
                            aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
992
993
                            ctx.max_seqlen_q, ctx.max_seqlen_k,
                            cu_seqlens_q, cu_seqlens_k,
994
995
                            q_, kv_[0], kv_[1], out_, dout_,
                            TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
996
997
998
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
999
                            qkv_layout=qkv_layout,
1000
                            attn_mask_type="causal",
1001
                            attn_bias_type=ctx.attn_bias_type,
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, 0]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                            ctx.max_seqlen_q, ctx.max_seqlen_k,
                            ctx.dropout_p, ctx.softmax_scale, True,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
                elif i >= (cp_size-rank-1):
                    if ctx.use_fused_attention:
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous()
                            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                            out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                            dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            q_ = q.view(-1, *q.shape[-3:])
                            # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn]
                            kv_ = kv[:, 0, ...].contiguous()
                            # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                            out_ = out.view(-1, *out.shape[-3:])
                            dout_ = dout.view(-1, *dout.shape[-3:])
                        aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
                        if attn_dbias is not None:
                            aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1045
1046
                            ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                            cu_seqlens_q, cu_seqlens_k//2,
1047
1048
                            q_, kv_[0], kv_[1], out_, dout_,
                            TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
1049
1050
1051
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1052
                            qkv_layout=qkv_layout,
1053
                            attn_mask_type="no_mask",
1054
                            attn_bias_type=ctx.attn_bias_type,
1055
1056
1057
1058
1059
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
1060
1061
1062
1063
1064
1065
                        if ctx.qkv_format == "thd":
                            # [2, t, np, hn] -> [2, t/2, np, hn]
                            kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
                        else:
                            # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                            kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
                            ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                            ctx.dropout_p, ctx.softmax_scale, False,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
                else:
                    if ctx.use_fused_attention:
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
                        if ctx.qkv_format == "bshd":
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous()
                            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                            kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous()
                            dout_ = dout[:, 1, ...].contiguous()
                        elif ctx.qkv_format == "sbhd":
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            q_ = q[1].contiguous()
                            # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                            kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
                            # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
                            out_ = out[1].contiguous()
                            dout_ = dout[1].contiguous()
                        aux_ctx_tensors = [softmax_lse_, ctx.rng_states[cp_size-i-1]]
                        if attn_dbias is not None:
                            aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
                        dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1102
1103
                            ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                            cu_seqlens_q//2, cu_seqlens_k,
1104
1105
                            q_, kv_[0], kv_[1], out_, dout_,
                            TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
1106
1107
1108
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
1109
                            qkv_layout=qkv_layout,
1110
                            attn_mask_type="no_mask",
1111
                            attn_bias_type=ctx.attn_bias_type,
1112
1113
                        )
                    else:
1114
1115
1116
1117
1118
1119
                        if ctx.qkv_format == "thd":
                            # [t, np, hn] -> [t/2, np, hn]
                            q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
1120
1121
1122
1123
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
1124
1125
1126
1127
1128
1129
1130
                        if ctx.qkv_format == "thd":
                            out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
                            dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
                        else:
                            # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                            dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
                            ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                            ctx.dropout_p, ctx.softmax_scale, False,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
            else:
                if ctx.use_fused_attention:
1143
1144
1145
1146
                    aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]]
                    if attn_dbias is not None:
                        aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]]
                    dq_, dk_, dv_, dbias_ = fused_attn_bwd(
1147
1148
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        cu_seqlens_q, cu_seqlens_k,
1149
1150
                        q, kv[0], kv[1], out, dout,
                        TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors,
1151
1152
1153
                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
1154
                        qkv_layout=qkv_layout,
1155
                        attn_mask_type="no_mask",
1156
                        attn_bias_type=ctx.attn_bias_type,
1157
1158
1159
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
1160
1161
                    q_ = q.view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
1162
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
1163
1164
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
1165
                    # [b, sq, np, hn] -> [b*sq, np, hn]
1166
1167
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
1168
1169
                    if _flash_attn_2_3_plus:
                        fa_optional_backward_kwargs["window_size"] = [-1, -1]
1170
1171
1172
1173
1174
                    _flash_attn_backward(
                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                        dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        ctx.dropout_p, ctx.softmax_scale, False,
1175
                        **fa_optional_backward_kwargs
1176
1177
                    )

1178
1179
1180
1181
1182
            if i >= (cp_size-rank-1) or not ctx.causal:
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
1183
1184
1185
1186
1187
1188
                if ctx.qkv_format == "bshd":
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
                elif ctx.qkv_format == "sbhd":
                    # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
                    dq_ = dq_.view(-1, *dq.shape[-3:])
1189

1190
            if ctx.causal:
1191
1192
1193
1194
1195
1196
                if i > (cp_size-rank-1):
                    dq.add_(dq_)
                elif i == (cp_size-rank-1):
                    if rank == (cp_size-1):
                        dq.copy_(dq_)
                    else:
1197
1198
1199
1200
1201
1202
                        if ctx.qkv_format == "bshd":
                            dq[:, 0, ...].copy_(dq_[:, 0, ...])
                            dq[:, 1, ...].add_(dq_[:, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dq[0].copy_(dq_[0])
                            dq[1].add_(dq_[1])
1203
1204
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add")
1205
                elif i > 0:
1206
1207
1208
1209
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].add_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].add_(dq_)
1210
1211
                    elif ctx.qkv_format == "thd":
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add")
1212
                else:
1213
1214
1215
1216
                    if ctx.qkv_format == "bshd":
                        dq[:, 1, ...].copy_(dq_)
                    elif ctx.qkv_format == "sbhd":
                        dq[1].copy_(dq_)
1217
1218
                    elif ctx.qkv_format == "thd":
                        tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy")
1219
1220
1221
1222
1223
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)
1224

1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
            if attn_dbias is not None:
                idx = (rank+i+1)%cp_size
                if i == (cp_size - 1) or not ctx.causal:
                    # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)]
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2)
                    attn_dbias[..., idx, :].copy_(dbias_[..., 0, :])
                    attn_dbias[..., (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :])
                elif i >= (cp_size-rank-1):
                    # [b, np, sq, sk//(2*cp)]
                    attn_dbias[..., idx, :].copy_(dbias_)
                else:
                    # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)]
                    dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2)
                    attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :])
                    attn_dbias_[..., 1, :, (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :])

1241
1242
1243
            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()
1244

1245
1246
1247
1248
            dkv = p2p_comm_buffers[(i+1)%2][1]
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
            if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1):
1249
1250
1251
1252
1253
1254
                if ctx.qkv_format == "bshd":
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                elif ctx.qkv_format == "sbhd":
                    # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn]
                    dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:])
1255
1256
1257
1258
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)
1259

1260
            if ctx.causal:
1261
1262
                if i == (cp_size-1):
                    if rank == 0:
1263
1264
1265
1266
1267
1268
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                            dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_[:, 0, ...])
                            dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
1269
1270
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy")
1271
1272
1273
1274
                    else:
                        dkv.add_(dkv_)
                elif i >= (cp_size-rank-1):
                    if i == 0 and rank == (cp_size-1):
1275
1276
1277
1278
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].copy_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].copy_(dkv_)
1279
1280
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none")
1281
                    else:
1282
1283
1284
1285
                        if ctx.qkv_format == "bshd":
                            dkv[:, :, 0, ...].add_(dkv_)
                        elif ctx.qkv_format == "sbhd":
                            dkv[:, 0, ...].add_(dkv_)
1286
1287
                        elif ctx.qkv_format == "thd":
                            tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none")
1288
1289
1290
1291
1292
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
1293
1294
1295
1296
1297
1298
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

        if ctx.causal:
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
            if ctx.qkv_format == "bshd":
                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                dq = dq.view(q.shape[0], -1, *q.shape[-2:])
                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
            elif ctx.qkv_format == "sbhd":
                # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
                dq = dq.view(-1, *q.shape[-3:])
                # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
                dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:])

        if attn_dbias is not None:
            # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
            attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

1314
        return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
1315
                None, None, None, None, None, None, attn_dbias, None, None
1316
1317
1318


def attn_forward_func_with_cp(
    is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
    dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale=None, qkv_format="bshd",
    attn_mask_type="causal", attn_bias_type="no_bias", attn_bias=None, deterministic=False,
    use_fused_attention=False
) -> torch.Tensor:
    """Attention implementation with context parallelism

    Checks that the requested configuration is supported and then forwards
    every argument, unchanged and in the same order, to `AttnFuncWithCP.apply`.

    Supported configurations (enforced by the assertions below):

    * `qkv_format` must be one of `bshd`, `sbhd` or `thd`;
    * `sbhd` requires `use_fused_attention=True`;
    * `thd` requires `use_fused_attention=False`;
    * `attn_mask_type` must be `causal` or `no_mask`;
    * a non-None `attn_bias` requires `use_fused_attention=True`.

    Returns
    -------
    Whatever `AttnFuncWithCP.apply` returns for the given arguments.

    Raises
    ------
    AssertionError
        If any of the configuration checks above fails.
    """
    assert qkv_format in ["bshd", "sbhd", "thd"], (
        f"QKV format of {qkv_format} is not supported with context parallelism!"
    )
    assert qkv_format != "sbhd" or use_fused_attention, (
        "FlashAttention does not support sbhd format!"
    )
    assert not (qkv_format == "thd" and use_fused_attention), (
        "FusedAttention does not support thd format!"
    )
    assert attn_mask_type in ["causal", "no_mask"], (
        f"Mask type of {attn_mask_type} is not supported with context parallelism!"
    )
    assert attn_bias is None or use_fused_attention, (
        "Attention bias is only supported with FusedAttention!"
    )

    out = AttnFuncWithCP.apply(
        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
        attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention
    )
    return out


1343
1344
1345
1346
1347
1348
1349
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.

    The module holds the inverse frequencies as the `inv_freq` buffer and, when
    called, returns the per-position rotation angles (not their cos/sin).
    """
    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int
            pre-trained max_position_embeddings before position interpolation
        """
        super().__init__()
        # Only a fraction of the dimension is rotated when rotary_percent < 1.
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        # inv_freq[k] = 1 / 10000^(2k / dim); allocated on the current CUDA device.
        inv_freq = 1.0 / (
            10000
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer('inv_freq', inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]` holding, for
            every position, the angles `position * inv_freq` repeated twice along
            the last dimension.
        """
        # Positions offset, offset + 1, ..., in the dtype/device of inv_freq.
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        # Position interpolation is applied only when both settings are given.
        if (self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None):
            if (max_seq_len >
                self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # Outer product: freqs[i, j] = seq[i] * inv_freq[j]
        freqs = torch.einsum('i , j -> i j', seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        #  2 * dim in dimension size
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))

1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465

class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Apply RoPE to `t` with the fused kernel matching `tensor_format`.

        `bshd` input is transposed to sequence-first before the kernel call and
        the result is transposed back. `cu_seqlens` is only passed to the kernel
        for the `thd` format.

        Raises
        ------
        ValueError
            If `tensor_format` is not one of `sbhd`, `bshd`, `thd`.
        """
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(
                t.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        # The backward kernels need the same freqs (and cu_seqlens for `thd`).
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Propagate `grad_output` through the fused RoPE kernel.

        Returns one entry per `forward` input (`t`, `freqs`, `tensor_format`,
        `cu_seqlens`); only `t` receives a gradient.
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # Exactly four values: forward takes four inputs besides ctx.
        return grad_input, None, None, None


1466
1467
1468
1469
1470
1471
1472
1473
1474
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


1475
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Returns
    -------
    torch.Tensor
        `t` with the first `d2` channels of its last dimension rotated; any
        remaining channels are passed through unchanged (unfused path).

    Raises
    ------
    AssertionError
        If `fused` is True with `thd` format but no `cu_seqlens`; if `fused` is
        False and `tensor_format` is not `sbhd`/`bshd`; or if the sequence length
        of `t` exceeds `freqs.shape[0]`.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert cur_seq_len <= max_seq_len, (
        f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    )
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)


cyanguwa's avatar
cyanguwa committed
1538
class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along one dimension, avoiding a copy in backward when possible.

    Forward is `torch.split` (applied to the raw `_data` for `Float8Tensor`
    inputs). Backward checks whether the incoming gradients already sit back to
    back in a single storage; if so it returns a view over that storage instead
    of concatenating them.
    """

    @staticmethod
    def forward(ctx,
                mixed_x_layer: torch.Tensor,
                split_dim: int,
                split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        """Split `mixed_x_layer` along `split_dim` and remember the split spec."""
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        if isinstance(mixed_x_layer, Float8Tensor):
            # Split the underlying data and re-wrap each piece like the input.
            return tuple(Float8Tensor.make_like(
                mixed_x_layer,
                data=x,
                ) for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim))
        return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)

    @staticmethod
    def backward(ctx,
                 *grad_outputs):
        """Reassemble the gradient of the unsplit tensor.

        Returns `(grad, None, None)`: a gradient for `mixed_x_layer` and none for
        `split_dim` / `split_size_or_sections`.
        """
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert (len(grad_outputs) == len(split_sizes)
                ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            # Assumes equally sized chunks; a shorter trailing chunk fails the
            # shape comparison below and takes the torch.cat fallback.
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        # Normalize a possibly negative split dimension.
        split_dim = (ctx.split_dim + dims) % dims

        if isinstance(grad_outputs[0], Float8Tensor):
            noop_ok = True
            strides = grad_outputs[0].stride()
            data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
            shape = list(grad_outputs[0].shape)
            for i, tensor in enumerate(grad_outputs):
                # Copy so the reference shape is not mutated through an alias.
                shape_i = list(shape)
                shape_i[split_dim] = split_sizes[i]
                # Expected element offset of chunk i within the shared storage.
                offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:])
                if (tensor.stride() != strides or
                    list(tensor.shape) != shape_i or
                    tensor._data.untyped_storage().data_ptr() != data_ptr or
                    tensor.storage_offset() != offset_size):
                    noop_ok = False
                    break
            if noop_ok:
                # Build a view spanning all chunks on top of the shared storage.
                ret = torch.Tensor().to(device=grad_outputs[0].device,
                                        dtype=grad_outputs[0]._data.dtype)
                new_shape = list(shape)
                new_shape[split_dim] = sum(split_sizes)
                ret.set_(grad_outputs[0]._data.untyped_storage(),
                         grad_outputs[0]._data.storage_offset(),
                         new_shape,
                         strides
                )
                return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None

            grad_outputs_data = [x._data for x in grad_outputs]
            return Float8Tensor.make_like(
                grad_outputs[0],
                data=torch.cat(grad_outputs_data, dim = split_dim)), None, None

        noop_ok = True
        strides = grad_outputs[0].stride()
        data_ptr = grad_outputs[0].untyped_storage().data_ptr()
        shape = list(grad_outputs[0].shape)
        for i, tensor in enumerate(grad_outputs):
            # Copy so the reference shape is not mutated through an alias.
            shape_i = list(shape)
            shape_i[split_dim] = split_sizes[i]
            # Expected element offset of chunk i within the shared storage.
            offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:])
            if (tensor.stride() != strides or
                list(tensor.shape) != shape_i or
                tensor.untyped_storage().data_ptr() != data_ptr or
                tensor.storage_offset() != offset_size):
                noop_ok = False
                break
        if noop_ok:
            # Build a view spanning all chunks on top of the shared storage.
            ret = torch.Tensor().to(device=grad_outputs[0].device,
                                    dtype=grad_outputs[0].dtype)
            new_shape = list(shape)
            new_shape[split_dim] = sum(split_sizes)
            ret.set_(grad_outputs[0].untyped_storage(),
                     grad_outputs[0].storage_offset(),
                     new_shape,
                     strides
            )
            return ret, None, None

        # Gradients are not laid out contiguously in one storage: concatenate.
        return torch.cat(grad_outputs, dim = split_dim), None, None
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        """
        Parameters
        ----------
        norm_factor: float
            Divisor applied to the raw query-key scores.
        attention_dropout: float, default = 0.0
            Dropout probability applied to the attention probabilities.
        attention_dropout_ctx: Callable, default = nullcontext
            Factory for the context manager entered around the dropout call.
        layer_number: int, default = None
            Layer index; used as an extra scale when QK layer scaling is enabled.
        """
        super().__init__()

        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        Computes scaled query-key scores (optionally with a bias), applies the
        masked softmax and dropout, and multiplies the result with the values.
        `bshd` inputs are transposed to `sbhd` internally; the output is
        `[sq, b, hp]` for `sbhd` and `[b, sq, hp]` for `bshd`.

        Raises
        ------
        AssertionError
            If `qkv_layout` is not in `QKVLayouts`, the format is `thd`, the
            head counts are not divisible for GQA, or a required bias is None.
        ValueError
            If `core_attention_bias_type` is not one of `no_bias`,
            `pre_scale_bias`, `post_scale_bias` or `alibi`.
        """

        assert (qkv_layout in QKVLayouts
            ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        # Keep only the letters of the first layout component, e.g. `sbh3d` -> `sbhd`.
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
        assert (qkv_format != 'thd'
            ), """UnfusedDotProductAttention does not support variable sequence lengths!"""
        if qkv_format == 'bshd':
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [x.transpose(0, 1)
                for x in [query_layer, key_layer, value_layer]]

        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # GQA: replicate key/value heads so they match the number of query heads.
        if key_layer.shape[2] != query_layer.shape[2]:
            assert (query_layer.shape[2]%key_layer.shape[2]==0
                ),"The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                    query_layer.shape[2] // key_layer.shape[2], dim = 2)
            value_layer = value_layer.repeat_interleave(
                    query_layer.shape[2] // value_layer.shape[2], dim = 2)

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(
            output_size[2], output_size[0] * output_size[1], -1
        )
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.norm_factor
        if apply_qk_layer_scaling:
            # Extra down-scaling here is undone by `softmax_scale` below.
            scale *= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            # beta=0.0: the preallocated buffer only provides shape/dtype.
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / scale),
            )

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # Bias is added before scaling.
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3])
            matmul_result /= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes)
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / scale),
            )
            # Bias is added after scaling; cast back to the query dtype.
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3]).to(
                dtype=query_layer.dtype)

        else:
            # Without this, the uninitialized buffer would be used as scores.
            raise ValueError(
                f"Unsupported core_attention_bias_type: {core_attention_bias_type}!")

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask, attn_mask_type, softmax_scale)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == 'sbhd':
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        elif qkv_format == 'bshd':
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """This class converts QKV from interleaved (s, b, ...) layout
       to separate contiguous q, k, v tensors in (b, s, ...) layout."""

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return separate q, k, v tensors produced by `tex.fa_prepare_fwd`.

        Only `query_layer` is passed to the kernel; the incoming `key_layer`
        and `value_layer` are not read.
        """
        # All inputs received are non-contiguous tensors.
        # The `query_layer` tensor is used to access the
        # full memory region of the QKV tensor.
        qkv = tex.fa_prepare_fwd(query_layer)
        # Split the leading dimension into three parts and drop that dimension.
        q, k, v = split_tensor_along_dim(qkv, 0, 3)
        query_layer = torch.squeeze(q, 0)
        key_layer = torch.squeeze(k, 0)
        value_layer = torch.squeeze(v, 0)
        return query_layer, key_layer, value_layer

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        dq: torch.Tensor,
        dk: torch.Tensor,
        dv: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Combine the three gradients with `tex.fa_prepare_bwd` and return the
        result split into three parts along its last dimension."""
        dqkv = tex.fa_prepare_bwd(dq, dk, dv)
        dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
        return dq, dk, dv

1858

1859
1860
1861
1862
1863
1864
1865
def _get_qkv_layout(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qkv_format: str = 'sbhd',
    ) -> str:
    """Get qkv layout.
1866

1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of sequences in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    """
1895

1896
1897
    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
1898

1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
    def run_iteratively(q, k, v):
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        stride = k.stride()
        check_strides_kv = all(stride == x.stride() for x in [k, v])

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = all(shape == x.shape for x in [k, v])

        last_dim_size = q.shape[-1]
        check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset()
                            for i, x in enumerate([q, k, v]))
        last_dim_size = k.shape[-1]
        check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset()
                            for i, x in enumerate([k, v]))

        last_two_dims_size = q.shape[-1] * q.shape[-2]
        check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset()
                            for i, x in enumerate([q, k, v]))
        last_two_dims_size = k.shape[-1] * k.shape[-2]
        check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
                            for i, x in enumerate([k, v]))

        if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
            and check_last_two_dims_offsets_qkv
            and not check_last_dim_offsets_qkv):
            # sb3hd, bs3hd, t3hd
            qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:]
        elif (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
            and check_last_dim_offsets_qkv):
            # sbh3d, bsh3d, th3d
            qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:]
        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
            and check_last_two_dims_offsets_kv
            and not check_last_dim_offsets_kv):
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:]
        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
            and check_last_dim_offsets_kv):
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:]
        elif check_strides_kv and check_shapes_kv:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            qkv_layout = '_'.join(list([qkv_format])*3)
        else:
            qkv_layout = 'not_supported'

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == 'not_supported':
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == 'not_supported':
1961
1962
        raise Exception("The provided qkv memory layout is not supported!")

1963
    return qkv_layout, q, k, v
1964

1965

1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
def check_set_window_size(
        attn_mask_type: str,
        window_size: Tuple[int, int] = None,
    ):
    """Validate the sliding window size against the mask type, or pick a default.

    When `window_size` is None, `(-1, 0)` is returned for mask types containing
    "causal" and `(-1, -1)` otherwise. A caller-provided `window_size` is returned
    unchanged; for causal mask types its second entry is asserted to be 0.
    """
    causal = "causal" in attn_mask_type

    # No window given: fall back to the default matching the mask type.
    if window_size is None:
        return (-1, 0) if causal else (-1, -1)

    if causal:
        assert (
            window_size[1] == 0
        ), "window_size[1] should be 0 when self_attn_mask_type includes 'causal'!"

    return window_size
1984

1985

1986
class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    The constructor asserts that the installed flash-attn version lies within the
    supported `[_flash_attn_version_required, _flash_attn_max_version]` range.
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            _flash_attn_version <= _flash_attn_max_version
        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        # The softmax scale passed to the kernels is 1.0 / norm_factor.
        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
    ) -> torch.Tensor:
        """flash-attn fprop

        The inputs must be FP16/BF16 CUDA tensors and `qkv_layout` must be one of
        `QKVLayouts`. The qkv format (`sbhd`, `bshd` or `thd`) is derived from
        `qkv_layout`. For `sbhd`/`bshd`, `max_seqlen_q`/`max_seqlen_kv` are taken
        from the tensor shapes; for `thd`, `cu_seqlens_q` and `cu_seqlens_kv` are
        required and missing maximum sequence lengths are derived from them.

        With padding mask types in `sbhd`/`bshd` format, tokens are packed before
        the kernel call and unpacked afterwards. For self attention,
        `attention_mask` is a single tensor; otherwise it is indexed as
        `(mask_q, mask_kv)`. It is only consulted when the cumulative sequence
        lengths are not provided.

        The returned tensor has heads and head size merged into the last
        dimension: `sb(hd)`, `bs(hd)` or `t(hd)` depending on the qkv format.
        """

        window_size = check_set_window_size(attn_mask_type, window_size)

        assert (
            query_layer.dtype in [torch.float16, torch.bfloat16]
            and key_layer.dtype in [torch.float16, torch.bfloat16]
            and value_layer.dtype in [torch.float16, torch.bfloat16]
            ), "FlashAttention currently only supports FP16 and BF16."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
            ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # e.g. 'sbh3d' -> 'sbhd', 'bshd_bs2hd' -> 'bshd'
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])

        if qkv_format == 'sbhd':
            # For now just 128, will make it more general in the future
            if (query_layer.shape[-1] == 128 and
                query_layer.shape[0] * query_layer.shape[1] >= 512 and
                qkv_layout == "sbh3d"):
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
                                                                             key_layer,
                                                                             value_layer)
            else:
                query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
                    for x in (query_layer, key_layer, value_layer)]
        elif qkv_format == 'bshd':
            query_layer, key_layer, value_layer = [x.contiguous()
                for x in (query_layer, key_layer, value_layer)]

        batch_size = query_layer.shape[0]

        if qkv_format in ['sbhd', 'bshd']:
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            if not context_parallel:
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.view(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

            if 'padding' in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (attention_mask is not None
                                ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (attention_mask is not None
                            ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(
                            attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(
                            attention_mask[1])
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(
                        indices_kv, key_layer, value_layer
                    )
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == 'thd':
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel:
            assert (
                window_size in ((-1, -1), (-1, 0))
                ), "Sliding window attention is not supported with context parallelism."
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training, query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    # sbhd inputs were already transposed to bshd above
                    qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic
                )
        else:

            from .cpu_offload import CPUOffloadEnabled
            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Only pass the keyword arguments the installed flash-attn version accepts.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                output = flash_attn_forward_func(
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=1.0/self.norm_factor, causal="causal" in attn_mask_type,
                    **fa_optional_forward_kwargs,
                )

        # Tokens are only packed (and `indices_q` only defined) in the sbhd/bshd
        # padding path above; thd inputs are never packed here, so there is
        # nothing to unpack for them.
        if qkv_format in ['sbhd', 'bshd'] and 'padding' in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == 'sbhd':
            # (bs)hd -> bs(hd) -> sb(hd)
            output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
        elif qkv_format == 'bshd':
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q, -1).contiguous()
        elif qkv_format == 'thd':
            # thd -> t(hd)
            output = output.view(output.shape[0], -1).contiguous()

        return output
2193

2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
def _combine_tensors(
        tensors: List[torch.Tensor],
        dim: int,
    ) -> torch.Tensor:
    """Combine tensors along a particular dimension

    No data is copied: the result is a new tensor object set onto the storage of
    `tensors[0]`, at its storage offset, with an extra dimension of size
    `len(tensors)` inserted at `dim`. The stride of the new dimension is the
    stride of dimension `dim - 1` of `tensors[0]` divided by the number of tensors.
    For a `Float8Tensor`, the view is built on its `_data` and wrapped via
    `Float8Tensor.make_like`.
    NOTE(review): only `tensors[0]` is inspected, so the other tensors are
    presumably expected to already sit interleaved in the same storage -- confirm
    against callers.
    """

    first = tensors[0]
    count = len(tensors)

    shape = list(first.shape)
    strides = list(first.stride())
    # Computed from the stride list before the new dimension is inserted.
    inserted_stride = int(strides[dim-1]/count)
    shape.insert(dim, count)
    strides.insert(dim, inserted_stride)

    is_fp8 = isinstance(first, Float8Tensor)
    # Storage, offset and dtype come from the raw payload for FP8 tensors.
    payload = first._data if is_fp8 else first

    combined = torch.Tensor().to(device=first.device, dtype=payload.dtype)
    combined.set_(
        payload.untyped_storage(),
        payload.storage_offset(),
        shape, strides)

    if is_fp8:
        combined = Float8Tensor.make_like(first, data=combined)

    return combined
2223

2224
2225
2226
2227
2228
2229
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
                dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd,
                fp8, fp8_meta):
        """Run fused attention forward on a packed `qkv` tensor.

        With `fp8` set, the FP8 backend is forced and `qkv` is either taken from a
        `Float8Tensor` (recipe `fp8_mha`) or cast to FP8 first; the output is
        returned as a `Float8Tensor` or cast back to `qkv_dtype` accordingly.
        Otherwise `fused_attn_fwd_qkvpacked` is called directly on `qkv`.

        The backward pass runs in FP8 only when `fp8` is set and the environment
        variable `NVTE_FP8_DPA_BWD` (default "1") is non-zero; in that case the
        high-precision `qkv`/output are not saved for backward.
        """
        if fp8:
            if _NVTE_DEBUG:
                print('[DotProductAttention]: using FP8 forward')
            if fp8_meta["recipe"].fp8_mha:
                assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA."
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.split('_'))
            assert (qkv_group == 1
                ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, \
                but found {qkv_layout}."
            if fp8_meta["recipe"].fp8_mha:
                qkv_fp8 = qkv._data
            else:
                # Flatten the last three dims for the cast, then restore the shape.
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv_fp8 = cast_to_fp8(qkv_c,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward).view(qkv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training, max_seqlen, cu_seqlens,
                qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale, dropout_p, fast_zero_fill, qkv_layout,
                attn_bias_type, attn_mask_type, rng_gen)
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=qkv.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            out_save = out_ret
            # FP8 MHA with a non-FP8 backward: save high-precision qkv/out instead.
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                qkv = cast_from_fp8(qkv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            # Scales are cloned so that backward sees the forward-time values.
            fp8_tensors = (qkv_fp8, out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            if _NVTE_DEBUG:
                print('[DotProductAttention]: using non-FP8 forward')
            out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
                fused_attention_backend, attn_bias,
                None, None, None, None, None, None,
                attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen)
            fp8_tensors = (None, None, None, None)
            out_save = out_ret

        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
        ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors)
        ctx.fp8_meta = fp8_meta
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = \
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute the gradient w.r.t. the packed `qkv` (and the bias, if any).

        Three paths exist: the flash-attn CUDA backward (`ctx.use_FAv2_bwd`), the
        FP8 fused backward (`ctx.fp8`) and the non-FP8 fused backward. The returned
        tuple carries `dqkv` in the slot of `qkv` and, unless the bias type is
        `no_bias` or `alibi`, the bias gradient in the slot of `attn_bias`.
        """
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (qkv, out, cu_seqlens,
            qkv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dqkv = torch.empty_like(qkv)
            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, qkv[:,0], qkv[:,1], qkv[:,2], out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dqkv[:,0], dqkv[:,1], dqkv[:,2],
                cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            dqkv = dqkv[..., :d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                if ctx.fp8:
                    if _NVTE_DEBUG:
                        print('[DotProductAttention]: using FP8 backward')
                    fp8_dtype_forward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        d_out_fp8 = d_out
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                            ).view(d_out.shape)
                    dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen, cu_seqlens,
                        qkv_fp8, out_fp8, d_out_fp8,
                        fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        fwd_scale_invs[META_QKV], # d_scale_qkv,
                        fwd_scale_invs[META_S], # d_scale_s,
                        fwd_scale_invs[META_O], # d_scale_o,
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
                        fwd_scales[META_S], # q_scale_s
                        ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
                        ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        dqkv = Float8Tensor(data=dqkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                    else:
                        dqkv_c_fp8 = dqkv_fp8.view(-1,
                            dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1])
                        dqkv = cast_from_fp8(dqkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"], META_DQKV,
                            fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
                else:
                    if _NVTE_DEBUG:
                        print('[DotProductAttention]: using non-FP8 backward')
                    # d_out is raw FP8 data here only if it arrived as a Float8Tensor above.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(qkv.dtype)
                    dqkv, *rest = fused_attn_bwd_qkvpacked(
                        ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
                        ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        None, None, None, None, None, None, None, None, None, None,
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # NOTE(review): the tuples below have 19 entries while forward takes 17
        # non-ctx arguments; the surplus trailing entries are all None.
        # if no_bias or alibi, return dqkv
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (None, None, None, dqkv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, return (dqkv, dbias)
        return (None, None, None, dqkv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)

2417

2418
2419
2420
2421
2422
2423
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input

    `forward` dispatches to `fused_attn_fwd_kvpacked` and `backward` to either
    `flash_attn_cuda_bwd` (when `use_FAv2_bwd` is set) or
    `fused_attn_bwd_kvpacked`. Both have an FP8 path and a non-FP8 path.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
                use_FAv2_bwd, fp8, fp8_meta):
        """Run the fused attention forward pass on separate `q` and packed `kv`.

        When `fp8` is truthy, `q`/`kv` are either taken as Float8Tensors
        (recipe `fp8_mha`) or cast to FP8 here, and the backend is forced to
        `FusedAttnBackend["FP8"]`. Otherwise the inputs are passed through to
        `fused_attn_fwd_kvpacked` unchanged with the caller-provided backend.

        Returns the attention output: a Float8Tensor when FP8 is used with
        `fp8_mha`, otherwise the tensor produced by `cast_from_fp8` (FP8 path)
        or by `fused_attn_fwd_kvpacked` directly (non-FP8 path).
        """
        if fp8:
            if _NVTE_DEBUG:
                print('[DotProductAttention]: using FP8 forward')
            if fp8_meta["recipe"].fp8_mha:
                assert (isinstance(q, Float8Tensor)
                    and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA."
                # Inputs are already FP8: adopt q's scale_inv for the QKV slot.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                q_fp8, kv_fp8 = q._data, kv._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split('_'))
                assert (qkv_group == 2
                    ), f"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, \
                    but found {qkv_layout}."
                q_fp8 = cast_to_fp8(q,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward).view(q.shape)
                # Collapse the trailing three dims of kv before casting,
                # then restore the original shape.
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv_fp8 = cast_to_fp8(kv_c,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward).view(kv.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale, dropout_p, fast_zero_fill, qkv_layout,
                attn_bias_type, attn_mask_type, rng_gen)
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            out_save = out_ret
            # FP8 forward but non-FP8 backward (NVTE_FP8_DPA_BWD=0): convert
            # q, kv and the output out of FP8 so they can be saved for backward.
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                q = cast_from_fp8(q._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
                kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                kv = cast_from_fp8(kv_c._data,
                    fp8_meta["scaling_fwd"],
                    META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            # Snapshot the forward scales so backward reads the values used here.
            fp8_tensors = (q_fp8, kv_fp8, out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            if _NVTE_DEBUG:
                print('[DotProductAttention]: using non-FP8 forward')
            out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, qkv_dtype, fused_attention_backend, attn_bias,
                None, None, None, None, None, None,
                attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen)
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None)

        # Backward runs in FP8 only if forward did and NVTE_FP8_DPA_BWD is non-zero.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # Save either the high-precision tensors or None placeholders; the FP8
        # tensors (or their None placeholders) are always appended.
        qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
        ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
        ctx.fp8_meta = fp8_meta
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        # NOTE(review): a non-FP8 backward always uses "F16_arbitrary_seqlen",
        # regardless of the backend used in forward -- confirm this is intended.
        ctx.fused_attention_backend = \
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for `q`, `kv` and, for bias types other than
        "no_bias"/"alibi", the attention bias.

        Returns a tuple aligned with `forward`'s positional inputs: `dq` at
        index 5, `dkv` at index 6, the bias gradient (if any) at index 8, and
        `None` everywhere else. The tuple carries surplus trailing `None`s
        beyond the number of forward inputs.
        """
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            # Keep the wrapper (for scale_inv / dtype) and work on the raw data.
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (q, kv, out, cu_seqlens_q, cu_seqlens_kv,
            q_fp8, kv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
            # kv[:,0] / kv[:,1] are passed as k / v
            # (assumes the packing dimension of kv is dim 1 -- TODO confirm).
            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, q, kv[:,0], kv[:,1], out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dkv[:,0], dkv[:,1],
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            # Trim the last dimension of the gradients to that of d_out.
            dq = dq[..., :d_out.shape[-1]]
            dkv = dkv[..., :d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                if ctx.fp8:
                    if _NVTE_DEBUG:
                        print('[DotProductAttention]: using FP8 backward')
                    fp8_dtype_forward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # Incoming gradient is already FP8; adopt its scale_inv.
                        d_out_fp8 = d_out
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                            ).view(d_out.shape)
                    dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q_fp8, kv_fp8, out_fp8, d_out_fp8,
                        fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        fwd_scale_invs[META_QKV], # d_scale_qkv,
                        fwd_scale_invs[META_S], # d_scale_s,
                        fwd_scale_invs[META_O], # d_scale_o,
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
                        fwd_scales[META_S], # q_scale_s
                        ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
                        ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # Hand the gradients back as Float8Tensors.
                        dq = Float8Tensor(data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                        dkv = Float8Tensor(data=dkv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                    else:
                        dq = cast_from_fp8(
                            dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"], META_DQKV,
                            fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
                        dkv_c_fp8 = dkv_fp8.view(-1,
                            dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1])
                        dkv = cast_from_fp8(dkv_c_fp8,
                            ctx.fp8_meta["scaling_bwd"], META_DQKV,
                            fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
                else:
                    if _NVTE_DEBUG:
                        print('[DotProductAttention]: using non-FP8 backward')
                    # d_out was unwrapped from a Float8Tensor above; convert it
                    # out of FP8 for the non-FP8 backward kernel.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dkv, *rest = fused_attn_bwd_kvpacked(
                        ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q, kv, out, d_out,
                        ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        None, None, None, None, None, None, None, None, None, None,
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # if no_bias or alibi, return (dq, dkv) only
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (None, None, None, None, None, dq, dkv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, return (dq, dkv, dbias)
        # NOTE(review): `rest` is only bound on the fused_attn_bwd_kvpacked
        # paths, not on the use_FAv2_bwd path -- confirm callers never combine
        # use_FAv2_bwd with a bias type that reaches this return.
        return (None, None, None, None, None, dq, dkv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)

2633
2634
2635
2636
2637
2638
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors

    `forward` dispatches to `fused_attn_fwd` and `backward` to either
    `flash_attn_cuda_bwd` (when `use_FAv2_bwd` is set) or `fused_attn_bwd`.
    Both have an FP8 path and a non-FP8 path.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
                use_FAv2_bwd, fp8, fp8_meta):
        """Run the fused attention forward pass on separate `q`, `k`, `v`.

        When `fp8` is truthy, the inputs are either taken as Float8Tensors
        (recipe `fp8_mha`) or cast to FP8 here, and the backend is forced to
        `FusedAttnBackend["FP8"]`. The cast is grouped according to the number
        of '_'-separated parts of `qkv_layout` (1: qkv packed, 2: kv packed,
        3: separate). Otherwise the inputs are passed to `fused_attn_fwd`
        unchanged with the caller-provided backend.

        Returns the attention output: a Float8Tensor when FP8 is used with
        `fp8_mha`, otherwise the tensor produced by `cast_from_fp8` (FP8 path)
        or by `fused_attn_fwd` directly (non-FP8 path).
        """
        if fp8:
            if _NVTE_DEBUG:
                print('[DotProductAttention]: using FP8 forward')
            fused_attention_backend = FusedAttnBackend["FP8"]
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if fp8_meta["recipe"].fp8_mha:
                assert (isinstance(q, Float8Tensor)
                    and isinstance(k, Float8Tensor)
                    and isinstance(v, Float8Tensor)), "q/k/v must be Float8Tensors for FP8 MHA."
                # Inputs are already FP8: adopt q's scale_inv for the QKV slot.
                fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split('_'))
                if qkv_group == 1:
                    # Recombine q/k/v, cast them to FP8 in a single call, then
                    # split them apart again.
                    dim = qkv_layout.find('3')
                    qkv = _combine_tensors([q,k,v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_fp8 = cast_to_fp8(qkv_c,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(qkv.shape)
                    q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1,1,1])
                    q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
                if qkv_group == 2:
                    # q is cast on its own; k/v are recombined and cast together.
                    q_fp8 = cast_to_fp8(q,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(q.shape)
                    dim = qkv_layout.split('_')[1].find('2')
                    kv = _combine_tensors([k,v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_fp8 = cast_to_fp8(kv_c,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(kv.shape)
                    k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1,1])
                    k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
                if qkv_group == 3:
                    q_fp8 = cast_to_fp8(q,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(q.shape)
                    k_fp8 = cast_to_fp8(k,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(k.shape)
                    v_fp8 = cast_to_fp8(v,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward).view(v.shape)
            out_fp8, aux_ctx_tensors = fused_attn_fwd(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale_inv[META_S],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale, dropout_p, fast_zero_fill, qkv_layout,
                attn_bias_type, attn_mask_type, rng_gen)
            if fp8_meta["recipe"].fp8_mha:
                out_ret = Float8Tensor(data=out_fp8,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=META_O,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=q.dtype,
                )
            else:
                out_ret = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
            out_save = out_ret

            # FP8 forward but non-FP8 backward (NVTE_FP8_DPA_BWD=0): convert
            # q, k, v and the output out of FP8 so they can be saved for backward.
            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                qkv_group = len(qkv_layout.split('_'))
                if qkv_group == 1:
                    dim = qkv_layout.find('3')
                    qkv = _combine_tensors([q,k,v], dim)
                    qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                    qkv_no_fp8 = cast_from_fp8(qkv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape)
                    q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1,1,1])
                    q, k, v = [x.squeeze(dim) for x in [q, k, v]]
                if qkv_group == 2:
                    q = cast_from_fp8(q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
                    dim = qkv_layout.split('_')[1].find('2')
                    kv = _combine_tensors([k,v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_no_fp8 = cast_from_fp8(kv_c._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape)
                    k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1,1])
                    k, v = [x.squeeze(dim) for x in [k, v]]
                if qkv_group == 3:
                    q = cast_from_fp8(q._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
                    k = cast_from_fp8(k._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[k.dtype]).view(k.shape)
                    v = cast_from_fp8(v._data,
                        fp8_meta["scaling_fwd"],
                        META_QKV, fp8_dtype_forward, TE_DType[v.dtype]).view(v.shape)
                out_save = cast_from_fp8(
                    out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                    fp8_meta["scaling_fwd"], META_O,
                    fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)

            # Snapshot the forward scales so backward reads the values used here.
            fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8,
                fp8_meta["scaling_fwd"].scale.clone(),
                fp8_meta["scaling_fwd"].scale_inv.clone())
        else:
            if _NVTE_DEBUG:
                print('[DotProductAttention]: using non-FP8 forward')
            out_ret, aux_ctx_tensors = fused_attn_fwd(
                is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
                None, None, None, None, None, None,
                attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen)
            out_save = out_ret
            fp8_tensors = (None, None, None, None, None, None)

        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            # Flag every tensor that will be saved for backward for offloading.
            tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
            # NOTE(review): the layout recorded on ctx is overridden here
            # whenever offloading is enabled -- confirm the rationale.
            qkv_layout = 'sbhd_sbhd_sbhd'
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        # Backward runs in FP8 only if forward did and NVTE_FP8_DPA_BWD is non-zero.
        ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # Save either the high-precision tensors or None placeholders; the FP8
        # tensors (or their None placeholders) are always appended.
        qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
        ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
        ctx.fp8_meta = fp8_meta
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        # NOTE(review): a non-FP8 backward always uses "F16_arbitrary_seqlen",
        # regardless of the backend used in forward -- confirm this is intended.
        ctx.fused_attention_backend = \
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out_ret

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for `q`, `k`, `v` and, for bias types other than
        "no_bias"/"alibi", the attention bias.

        Returns a tuple aligned with `forward`'s positional inputs: `dq`, `dk`,
        `dv` at indices 5-7, the bias gradient (if any) at index 9, and `None`
        everywhere else. The tuple carries surplus trailing `None`s beyond the
        number of forward inputs.
        """
        if ctx.fp8_meta["recipe"].fp8_mha:
            assert (isinstance(d_out, Float8Tensor)
                ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
            # Keep the wrapper (for scale_inv / dtype) and work on the raw data.
            d_out_f8tensor = d_out
            d_out = d_out._data

        d_out = d_out.contiguous()
        (q, k, v, out, cu_seqlens_q, cu_seqlens_kv,
            q_fp8, k_fp8, v_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dk, dv,
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            # Trim the last dimension of the gradients to that of d_out.
            dq = dq[..., :d_out.shape[-1]]
            dk = dk[..., :d_out.shape[-1]]
            dv = dv[..., :d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("_FusedAttn"):
                if ctx.fp8:
                    if _NVTE_DEBUG:
                        print('[DotProductAttention]: using FP8 backward')
                    fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                    fp8_dtype_backward = get_fp8_te_dtype(
                        ctx.fp8_meta["recipe"], fprop_tensor=False)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # Incoming gradient is already FP8; adopt its scale_inv.
                        d_out_fp8 = d_out
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                    else:
                        d_out_fp8 = cast_to_fp8(
                            d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
                            ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
                            ).view(d_out.shape)
                    dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8,
                        fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        fwd_scale_invs[META_QKV], # d_scale_qkv,
                        fwd_scale_invs[META_S], # d_scale_s,
                        fwd_scale_invs[META_O], # d_scale_o,
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                        ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
                        fwd_scales[META_S], # q_scale_s
                        ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
                        ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
                        ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
                    if ctx.fp8_meta["recipe"].fp8_mha:
                        # Hand the gradients back as Float8Tensors.
                        dq = Float8Tensor(data=dq_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                        dk = Float8Tensor(data=dk_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                        dv = Float8Tensor(data=dv_fp8,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=META_DQKV,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=d_out_f8tensor.dtype,
                            )
                    else:
                        # Convert the gradients out of FP8, grouped the same
                        # way forward grouped the inputs.
                        qkv_group = len(ctx.qkv_layout.split('_'))
                        if qkv_group == 1:
                            dim = ctx.qkv_layout.find('3')
                            dqkv_fp8 = _combine_tensors([dq_fp8,dk_fp8,dv_fp8], dim)
                            dqkv_c_fp8 = dqkv_fp8.view(-1,
                                dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1])
                            dqkv = cast_from_fp8(dqkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
                            dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1,1,1])
                            dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
                        if qkv_group == 2:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
                            dim = ctx.qkv_layout.split('_')[1].find('2')
                            dkv_fp8 = _combine_tensors([dk_fp8,dv_fp8], dim)
                            dkv_c_fp8 = dkv_fp8.view(-1,
                                dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1])
                            dkv = cast_from_fp8(dkv_c_fp8,
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
                            dk, dv = _SplitAlongDim.apply(dkv, dim, [1,1])
                            dk, dv = [x.squeeze(dim) for x in [dk, dv]]
                        if qkv_group == 3:
                            dq = cast_from_fp8(
                                dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
                            dk = cast_from_fp8(
                                dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dk_fp8.shape)
                            dv = cast_from_fp8(
                                dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]),
                                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                                fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape)
                else:
                    if _NVTE_DEBUG:
                        print('[DotProductAttention]: using non-FP8 backward')
                    # d_out was unwrapped from a Float8Tensor above; convert it
                    # out of FP8 for the non-FP8 backward kernel.
                    if d_out.dtype == torch.uint8:
                        d_out = d_out_f8tensor.from_float8(q.dtype)
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q, k, v, out, d_out,
                        ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        None, None, None, None, None, None, None, None, None, None,
                        ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                        ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # if no_bias or alibi, return (dq, dk, dv) only
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (None, None, None, None, None, dq, dk, dv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, return (dq, dk, dv, dbias)
        # NOTE(review): `rest` is only bound on the fused_attn_bwd paths, not
        # on the use_FAv2_bwd path -- confirm callers never combine
        # use_FAv2_bwd with a bias type that reaches this return.
        return (None, None, None, None, None, dq, dk, dv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)

2942

2943
class FusedAttention(TransformerEngineBaseModule):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Store the attention configuration and set up cuDNN workspace behaviour.

        Parameters
        ----------
        norm_factor : float
                     divisor of the attention scores; `forward` passes
                     `1.0 / norm_factor` as the softmax scale.
        attention_dropout : float, default = 0.0
                           dropout probability, applied only while `self.training` is set.
        attention_dropout_ctx : Optional[Callable], default = `nullcontext`
                               callable returning a context manager that `forward` enters
                               around the attention call.
        attention_type : str, default = `self`
                        for padding masks, "`self`" derives both `cu_seqlens` from a single
                        mask; any other value expects a pair of masks.
        layer_number : Optional[int], default = `None`
                      stored as `self.layer_number`; `None` is replaced by 1.
        deterministic : bool, default = `False`
                       if `True`, sets `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT` to "-1".

        .. note::

            This constructor writes to `os.environ`, so the workspace settings it applies
            are process-wide rather than per-module.
        """
        super().__init__()

        self.norm_factor = norm_factor
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # flash-attn backward is opt-in and only considered on compute capability 9.0
        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
                        and get_device_compute_capability() == (9, 0))
        self.layer_number = 1 if layer_number is None else layer_number
        if deterministic:
            # workspace optimization path is deterministic
            os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
        # - unset:       enables workspace optimization when required workspace is <= 256MB
        #                or when bias gradient needs to be computed
        # - n:           enables workspace optimization when required workspace is <= n bytes
        # - -1:          enables workspace optimization always
        # - 0:           disables workspace optimization always
        # NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT, when set to "0" or "1", overrides the
        # value chosen above (including the one implied by `deterministic`).
        if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
            when loading older TransformerEngine checkpoints. Will phase out
            this hook in TransformerEngine 2.0.
            """
            missing_keys = incompatible_keys.missing_keys
            # Filter with a slice assignment: the list object must be updated in place
            # for the caller to see the change, and removing entries while iterating
            # the same list would skip the element following each removal.
            missing_keys[:] = [
                key for key in missing_keys
                if 'fused_attention._extra_state' not in key
            ]
        self.register_load_state_dict_post_hook(remove_extra_states_check)

    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[Float8Tensor]:
        """Needs override.

        This implementation has no body beyond the docstring, so it ignores
        `is_first_microbatch` and returns `None`.
        """

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        fused_attention_backend:
            tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
        is_first_microbatch: Optional[bool] = None,
    ) -> torch.Tensor:
        """fused attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
                   CUDA tensors whose dtype is `torch.float16`, `torch.bfloat16`
                   or `torch.uint8`.
        qkv_layout : str, default = `sbh3d`
                   must be a member of `QKVLayouts`. The dimension format is taken from
                   the letters of the first `_`-separated group; `thd` is rejected.
        cu_seqlens_q, cu_seqlens_kv : Optional[torch.Tensor], default = `None`
                   cumulative sequence lengths. For padding mask types, if either one is
                   missing both are derived from `attention_mask`; for other mask types a
                   missing one is built with `_get_full_cu_seqlens`.
        max_seqlen_q, max_seqlen_kv : Optional[int], default = `None`
                   for `sbhd`/`bshd` formats the values passed in are overwritten with
                   the sizes read from the tensor shapes.
        attn_mask_type : str, default = `causal`
                   mask type; any value containing "padding" takes the padding path.
        attention_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]
                   only read when a padding mask needs `cu_seqlens` to be derived:
                   a single mask when `self.attention_type == "self"`, otherwise a pair
                   indexed as `[0]` (query) and `[1]` (key/value).
        fused_attention_backend : tex.NVTE_Fused_Attn_Backend
                   backend to run; `NVTE_No_Backend` (the default) is rejected.
        core_attention_bias_type : str, default = `no_bias`
                   bias type forwarded to the attention kernel.
        core_attention_bias : Optional[torch.Tensor], default = `None`
                   bias tensor forwarded to the attention kernel.
        fast_zero_fill : bool, default = `True`
                   forwarded to `FusedAttnFunc` (not used on the context-parallel path).
        cp_group, cp_global_ranks, cp_stream
                   context-parallel settings. Context parallelism is active when
                   `cp_group` is given and its world size is not 1.
        is_first_microbatch : Optional[bool], default = `None`
                   forwarded to `prepare_forward` on the non-context-parallel path.

        Returns
        -------
        torch.Tensor
            attention output with its last two dimensions flattened into one.

        Raises
        ------
        AssertionError
            for an unsupported backend, dtype, device, layout or
            context-parallel combination.
        RuntimeError
            if a padding mask is requested without `attention_mask` or both `cu_seqlens`.
        """

        assert (fused_attention_backend
            != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
            ), 'No fused attention backend supports this input combination!'
        assert (
            (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
            ), 'FusedAttention only supports FP16, BF16 and FP8 (uint8) data types.'
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'FusedAttention only supports CUDA tensors.'
        assert (
            qkv_layout in QKVLayouts
            ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # e.g. "sb3hd" -> "sbhd", "bshd_bs2hd" -> "bshd"
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
        assert (
            qkv_format != 'thd'
            ), 'FusedAttention does not support qkv_format = thd!'

        if qkv_format in ['sbhd', 'bshd']:
            if qkv_format == 'sbhd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1], query_layer.shape[0], key_layer.shape[0])
            if qkv_format == 'bshd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
            if 'padding' in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        qkv_dtype = TE_DType[query_layer.dtype]

        # flash-attn backward is only used without bias on the arbitrary-seqlen backend
        use_FAv2_bwd = (self.use_FAv2_bwd
                and (core_attention_bias_type == "no_bias")
                and (fused_attention_backend
                    == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))

        if context_parallel:
            assert (fused_attention_backend
                == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
                ), f"{fused_attention_backend} does not work with context parallelism!"
            assert (
                core_attention_bias_type not in ["alibi"]
            ), f"{core_attention_bias_type} is not supported with context parallelism!"
            query_layer, key_layer, value_layer = [x.contiguous()
                for x in (query_layer, key_layer, value_layer)]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv,
                    max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    qkv_format=qkv_format,
                    attn_mask_type=attn_mask_type,
                    attn_bias_type=core_attention_bias_type,
                    attn_bias=core_attention_bias,
                    use_fused_attention=True,
                )
        else:
            with self.prepare_forward(query_layer,
                is_first_microbatch,
                num_gemms=3,
                allow_non_contiguous=True) as query_layer:
                with self.attention_dropout_ctx():
                    forced_fp8_dpa = ""
                    # fp8_mha implies fp8_dpa: switch it on in the recipe if it was off
                    if self.fp8_meta["recipe"].fp8_mha:
                        if not self.fp8_meta["recipe"].fp8_dpa:
                            self.fp8_meta["recipe"].fp8_dpa = True
                            forced_fp8_dpa = " (forced)"
                    if _NVTE_DEBUG:
                        print("[DotProductAttention]: "
                            f"""using fp8_recipe.fp8_mha={self.fp8_meta["recipe"].fp8_mha}, """
                            f"""fp8_recipe.fp8_dpa={self.fp8_meta["recipe"].fp8_dpa}"""
                            f"""{forced_fp8_dpa} and """
                            f"""NVTE_FP8_DPA_BWD={int(os.getenv("NVTE_FP8_DPA_BWD", "1"))}""")
                    output = FusedAttnFunc.apply(
                        self.training,
                        max_seqlen_q, max_seqlen_kv,
                        cu_seqlens_q, cu_seqlens_kv,
                        query_layer, key_layer, value_layer,
                        qkv_dtype,
                        core_attention_bias,
                        1.0/self.norm_factor,
                        self.attention_dropout if self.training else 0.0,
                        fast_zero_fill,
                        qkv_layout,
                        core_attention_bias_type,
                        attn_mask_type,
                        None, # rng_gen
                        fused_attention_backend,
                        use_FAv2_bwd,
                        self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        self.fp8_meta,
                    )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
3174
3175


3176
3177
3178
3179
3180
3181
3182
class DotProductAttention(torch.nn.Module):
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

3183
        Argument :attr:`attention_mask` in the `forward` call is only used when
3184
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
3185
3186
3187

    .. warning::

3188
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
3189
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
3190
3191
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
3192
3193
3194
3195
3196
3197
3198

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels : int
                number of key-value channels.
3199
3200
3201
3202
3203
3204
3205
3206
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
3207
3208
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
3209
    attn_mask_type: str, default = `causal`
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219
3220
3221
3222
3223
                   type of attention mask passed into softmax operation, options are "`no_mask`",
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`", and
                   "`arbitrary`", where "`padding,causal`" and "`causal,padding`" are equivalent.
                   This arg can be overridden by :attr:`attn_mask_type` in the `forward` method.
                   It is useful for cases involving compilation/tracing, e.g. ONNX export, and the
                   forward arg is useful for dynamically changing mask types, e.g. a different mask
                   for training and inference. For "`no_mask`", no attention mask is applied. For
                   "`causal`" or the causal mask in "`padding,causal`", TransformerEngine calculates
                   and applies an upper triangular mask to the softmax input. No user input is
                   needed. For "`padding`" or the padding mask in "`padding,causal`", users need to
                   provide the locations of padded tokens either via :attr:`cu_seqlens_q` and
                   :attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
                   in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
                   need to provide a mask that is broadcastable to the shape of softmax input.
3224
3225
3226
3227
3228
3229
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
3230
3231
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
3232
3233
3234
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
               `h` the number of heads, `d` head size, and `t` the total number of tokens
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
               For that, please use `_get_qkv_layout` to gain the layout information.
3245
3246
3247
3248
3249
3250
3251
3252
3253

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
3254
3255
3256
3257
3258
3259
3260
3261
3262
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
3263
3264
3265
3266
3267
3268
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Record the attention configuration and build the available backends.

        The unfused backend is always created; the flash and fused backends are
        created only when their environment switch (`NVTE_FLASH_ATTN`,
        `NVTE_FUSED_ATTN`) is non-zero and the device compute capability is at
        least (8, 0).
        """
        super().__init__()

        self.qkv_format = qkv_format
        # normalise "a,b" to "a_b" and fold "causal_padding" into "padding_causal"
        attn_mask_type = attn_mask_type.replace(",","_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        # an explicit process group takes precedence over the `tp_size` argument
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

        self.hidden_size_per_attention_head = kv_channels
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        # Use the resolved tensor-parallel size: when only `tp_group` is passed,
        # the raw `tp_size` argument keeps its default and would leave the
        # per-partition group count undivided.
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (num_attention_heads % self.num_gqa_groups == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"

        self.rng_states_tracker = None
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            self.rng_states_tracker = get_rng_state_tracker()
            set_all_rng_states(self.rng_states_tracker.get_states())
            attention_dropout_ctx = self.rng_states_tracker.fork

        norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.device_compute_capability = get_device_compute_capability()
        # deterministic if requested through the environment or through PyTorch
        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) \
                             or torch.are_deterministic_algorithms_enabled()

        self.use_flash_attention = (
            int(os.getenv("NVTE_FLASH_ATTN", "1"))
            and self.device_compute_capability >= (8, 0)
        )
        if not _flash_attn_2_4_1_plus and self.deterministic:
            self.use_flash_attention = False
            warnings.warn(
                "Disabling usage of FlashAttention since version <2.4.1 does not support "
                "deterministic execution. In order to use FA with deterministic behavior,"
                " please install FlashAttention version >=2.4.1."
            )

        self.use_fused_attention = (
            int(os.getenv("NVTE_FUSED_ATTN", "1"))
            and self.device_compute_capability >= (8, 0)
        )

        assert (
            attention_type in AttnTypes
        ), f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        # keyword arguments shared by every backend constructor
        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        if self.use_flash_attention:
            self.flash_attention = FlashAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        if self.use_fused_attention:
            self.fused_attention = FusedAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        self.unfused_attention = UnfusedDotProductAttention(
            norm_factor, **attn_kwargs, layer_number=layer_number)

    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Run `attention_func` through `checkpoint` and return its result.

        All positional and keyword arguments after `attention_func` are forwarded
        to it unchanged; checkpointing uses this module's RNG state tracker getter
        and tensor parallel group, with `distribute_saved_activations` disabled.
        """

        def run_attention(*args, **kwargs):
            # plain-function wrapper around the attention callable
            return attention_func(*args, **kwargs)

        return checkpoint(
            run_attention,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

3395
3396
3397
3398
3399
3400
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Store the context parallel settings on this module so that they
        are in place before the forward pass runs.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # the three attributes are assigned left to right, as three separate stores
        self.cp_group, self.cp_global_ranks, self.cp_stream = (
            cp_group,
            cp_global_ranks,
            cp_stream,
        )

3418
    @no_torch_dynamo(recursive=False)
3419
3420
3421
3422
3423
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
3424
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
3425
3426
3427
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
3428
3429
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
3430
        attn_mask_type: Optional[str] = None,
3431
        window_size: Optional[Tuple[int, int]] = None,
3432
        checkpoint_core_attention: bool = False,
3433
3434
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
3435
        alibi_slopes: Optional[torch.Tensor] = None,
3436
        fast_zero_fill: bool = True,
3437
        inference_params: Optional[InferenceParams] = None,
3438
        is_first_microbatch: Optional[bool] = None,
3439
3440
3441
3442
3443
3444
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

3445
3446
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
3447
3448
3449
3450
3451
3452
3453
3454
3455

        .. note::

            Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
            must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
            :attr:`num_attention_heads`, :attr:`kv_channels`). Output of shape
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
            * :attr:`kv_channels`) is returned.

3456
3457
        .. note::

3458
3459
3460
3461
3462
3463
3464
3465
3466
3467
3468
3469
3470
3471
3472
3473
3474
3475
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
3476
3477
3478
3479
3480
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
3481

3482
3483
3484
3485
3486
3487
3488
3489
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
3490
3491
3492
3493
3494
3495
3496
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
3497
3498
3499
3500
3501
3502
3503
3504
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
3505
3506
3507
3508
3509
3510
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
3511
3512
3513
        attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
                       `arbitrary`}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
3514
        window_size: Optional[Tuple[int, int]], default = `None`
3515
                    Sliding window size for local attention.
3516
3517
3518
3519
3520
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
3521
        core_attention_bias_type: str, default = `no_bias`
3522
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
3523
        core_attention_bias: Optional[torch.Tensor], default = `None`
3524
3525
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
3526
3527
3528
3529
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
3530
        fast_zero_fill: bool, default = `True`
3531
                    Whether to use the fast path to set output tensors to 0 or not.
3532
3533
3534
3535
3536
3537
3538
3539
3540
3541
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
3542
3543
3544
3545
3546
3547
3548
3549
3550
3551
3552
3553
3554
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
3555
3556
        """

3557
3558
3559
3560
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'DotProductAttention only supports CUDA tensors.'

3561
3562
3563
        assert (key_layer.shape == value_layer.shape
            ), "Keys and values must have the same shape!"

3564
3565
        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)
3566
        if attn_mask_type is None:
3567
            attn_mask_type = self.attn_mask_type
3568
3569
3570
3571
3572
3573
3574
3575
        else:
            attn_mask_type = attn_mask_type.replace(",","_")
            if attn_mask_type == "causal_padding":
                attn_mask_type = "padding_causal"

        assert (attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"

3576
3577
3578
3579
3580
3581
3582
3583
        if self.rng_states_tracker is not None and is_graph_capturing():
            assert (
                isinstance(self.rng_states_tracker, CudaRNGStatesTracker)
            ), "Unsupported RNG states tracker."
            assert (
                graph_safe_rng_available()
            ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."

3584
3585
3586
        if window_size is None:
            window_size = self.window_size

3587
3588
        if qkv_format is None:
            qkv_format = self.qkv_format
3589

3590
3591
3592
3593
3594
3595
3596
3597
3598
3599
3600
3601
3602
3603
3604
3605
3606
3607
3608
3609
3610
3611
3612
3613
3614
3615
3616
3617
3618
3619
3620
3621
3622
        if inference_params is not None:
            assert self.layer_number is not None, "Layer number must be set!"

            if qkv_format == "bshd":
                key_layer = key_layer.transpose(0, 1)
                value_layer = value_layer.transpose(0, 1)

            (inference_key_memory, inference_value_memory,
            ) = inference_params.key_value_memory_dict[self.layer_number]

            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)

            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)

            # Copy keys and values into KV-cache
            inference_key_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
            inference_value_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

            if qkv_format == "bshd":
                key_layer = key_layer.transpose(0, 1)
                value_layer = value_layer.transpose(0, 1)

            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

3623
        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
3624
3625
3626
3627
3628
3629
3630
3631
3632
3633
3634
3635
3636
3637
3638
3639
3640
            and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
        assert (qkv_format in ['sbhd', 'bshd', 'thd']
            ), "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

        if qkv_format == 'thd':
            assert (all(len(x.shape) == 3 for x in (query_layer, key_layer, value_layer))
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
                and len(cu_seqlens_q.shape) == 1
                and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
            assert (cu_seqlens_q.dtype == torch.int32
                and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
3641
3642
3643
3644
3645
3646
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()
3647
3648
3649
3650
3651
3652
3653
3654
3655
3656
3657
3658
3659
3660
3661
3662
3663
3664
3665

        if qkv_format in ['sbhd', 'bshd']:
            assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
            if qkv_format == 'sbhd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
            if qkv_format == 'bshd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
            if cu_seqlens_q is not None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                assert (all(seqlens_q <= max_seqlen_q)
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                    the sequence dimention in 'query_layer'!"""
            if cu_seqlens_kv is not None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                assert (all(seqlens_kv <= max_seqlen_kv)
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                    the sequence dimention in 'key_layer' and 'value_layer'!"""

3666
3667
3668
3669
3670
3671
3672
3673
        if (isinstance(query_layer, Float8Tensor)
            and isinstance(key_layer, Float8Tensor)
            and isinstance(value_layer, Float8Tensor)):
            qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout(
                query_layer._data, key_layer._data, value_layer._data, qkv_format = qkv_format)
        else:
            qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
                query_layer, key_layer, value_layer, qkv_format = qkv_format)
3674

3675
3676
        # The priority for attention backends (subject to availability and clearing the filters)
        # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
3677
        use_flash_attention = self.use_flash_attention
3678
        use_fused_attention = self.use_fused_attention
3679
        use_unfused_attention = True
3680

3681
3682
3683
        # The following section filters out some backends based on
        # certain asserts before executing the forward pass.

3684
3685
3686
3687
3688
        # Filter: ONNX export.
        if is_in_onnx_export_mode():
            use_flash_attention = False
            use_fused_attention = False

3689
        # Filter: Input type.
3690
3691
3692
        if (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16]
3693
            or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
3694
3695
        ):
            use_flash_attention = False
3696
3697
3698
3699
        if (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16]
        ):
3700
            use_fused_attention = False
3701

3702
        # Filter: Device and dimensions.
3703
        # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
3704
3705
3706
3707
3708
        # FAv2 requires head_dim % 8 == 0
        if (key_layer.shape[-1] > 256
            or key_layer.shape[-1] % 8 != 0
            or (key_layer.shape[-1] > 192
                and self.device_compute_capability not in ((8, 0), (9, 0)))):
3709
3710
            use_flash_attention = False

3711
        # Filter: cross attention + causal mask.
3712
3713
3714
        # (in training mode)
        if (inference_params is None
            and _flash_attn_2_1_plus
3715
            and "causal" in attn_mask_type
3716
3717
            and max_seqlen_q != max_seqlen_kv
        ):
3718
            warnings.warn(
3719
3720
                "In training mode, disable the use of FlashAttention since version 2.1+ has "
                "changed its behavior for causal mask in cross attention. See "
3721
3722
3723
3724
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False

3725
3726
3727
        context_parallel = (self.cp_group is not None and \
            get_distributed_world_size(self.cp_group) != 1)

3728
3729
3730
3731
3732
3733
3734
        # Filter: sliding window attention.
        # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
        if window_size not in ((-1, -1), (-1, 0)):
            use_fused_attention = False
            if (not _flash_attn_2_3_plus) or context_parallel:
                use_flash_attention = False

3735
        # Filter: Attention mask type.
3736
        #   attn_mask_type(s)    |     supported backends
3737
        # ------------------------------------------------
3738
3739
        #   no_mask              |     All
        #   padding              |     UnfusedDotProductAttention, FlashAttention, FusedAttention
3740
        #   causal               |     All
3741
        #   padding + causal     |     FlashAttention, FusedAttention
3742
3743
3744
3745
3746
        #   arbitrary            |     UnfusedDotProductAttention
        #
        if attn_mask_type == "arbitrary":
            use_flash_attention = False
            use_fused_attention = False
3747
3748
3749
3750
3751

        if (inference_params is None
            and "causal" in attn_mask_type
            and max_seqlen_q != max_seqlen_kv
        ):
3752
            use_unfused_attention = False
3753

3754
3755
3756
3757
3758
3759
3760
3761
3762
3763
3764
3765
3766
3767
3768
3769
3770
3771
3772
3773
3774
3775
3776
3777
3778
3779
3780
3781
        # Filter: bias.
        global _alibi_cache
        if alibi_slopes is not None:
            assert (core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
            if self.layer_number == 1:
                _alibi_cache["_alibi_slopes_require_update"] = True
                _alibi_cache["_alibi_bias_require_update"] = True
        if core_attention_bias_type == "alibi":
            assert (core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
            if (_alibi_cache["_num_heads"] != query_layer.shape[-2]
                or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
                or _alibi_cache["_alibi_slopes"] is None):
                _alibi_cache["_alibi_slopes_require_update"] = True
                _alibi_cache["_alibi_bias_require_update"] = True

        if core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None:
            use_flash_attention = False

        fu_core_attention_bias_type = core_attention_bias_type
        fu_core_attention_bias = core_attention_bias
        if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None:
            fu_core_attention_bias_type = "post_scale_bias"
            _, fu_core_attention_bias = get_alibi(
                query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes,
                bias_dtype=query_layer.dtype)
3782
3783
3784
3785
3786
3787
3788
        if (use_fused_attention
            and fu_core_attention_bias_type == "post_scale_bias"
            and (fu_core_attention_bias.shape[0] != 1
            or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
            if fu_core_attention_bias.requires_grad:
                # remove this line when cuDNN adds bwd support for
                # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
3789
                use_fused_attention = False
3790
            else:
3791
3792
3793
                # max512 backend will only support [1, h, s, s]
                os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

3794
3795
        if use_fused_attention:
            fused_attention_backend = tex.get_fused_attn_backend(
3796
3797
3798
3799
                TE_DType[query_layer.dtype]
                if not isinstance(query_layer, Float8Tensor) else query_layer._fp8_dtype,
                TE_DType[key_layer.dtype]
                if not isinstance(key_layer, Float8Tensor) else key_layer._fp8_dtype,
3800
                QKVLayout[qkv_layout],
3801
                AttnBiasType[fu_core_attention_bias_type],
3802
                AttnMaskType[attn_mask_type],
3803
                self.attention_dropout,
3804
3805
3806
3807
3808
3809
                query_layer.shape[-2], # num_attn_heads
                key_layer.shape[-2], # num_gqa_groups
                max_seqlen_q,
                max_seqlen_kv,
                query_layer.shape[-1], # head_dim
            )
3810
3811
            # DPA does not support FP8; for FP8, use cpp_extensions modules directly
            is_backend_avail = (fused_attention_backend in
3812
3813
3814
                [FusedAttnBackend["F16_max512_seqlen"],
                FusedAttnBackend["F16_arbitrary_seqlen"],
                FusedAttnBackend["FP8"]])
3815
3816
3817
3818
            use_fused_attention = ( \
                use_fused_attention and is_backend_avail and \
                (not context_parallel or \
                 fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
3819
3820
3821
3822
3823
            if (fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
                and fu_core_attention_bias_type == "post_scale_bias"
                and (fu_core_attention_bias.shape[0] != 1
                or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
                use_fused_attention = False
3824

3825
3826
3827
3828
3829
3830
3831
3832
3833
3834
3835
3836
3837
3838
3839
3840
3841
3842
        # Filter: determinism.
        # backend                                  | deterministic
        # ---------------------------------------------------------
        # flash-attn v1                            | yes
        # flash-attn v2                            | no
        # FusedAttnBackend["F16_max512_seqlen"]    | yes
        # FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
        # UnfusedDotProductAttention               | yes
        #
        # Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
        # on sm90 architectures.
        #
        if (use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and self.deterministic
            and self.device_compute_capability != (9, 0)):
            use_fused_attention = False

3843
3844
3845
3846
3847
3848
        # Select FusedAttention on sm90 and FlashAttention on others for performance
        if (use_flash_attention
            and use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
            if self.device_compute_capability == (9, 0):
                use_flash_attention = False
3849
3850

        if use_flash_attention:
3851
3852
            if _NVTE_DEBUG:
                print("[DotProductAttention]: using flash-attn",_flash_attn_version)
3853
3854
3855
            if core_attention_bias_type == "alibi":
                alibi_slopes, _ = get_alibi(
                    query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes)
3856
3857
3858
3859
3860
3861
3862
3863
            return self.flash_attention(query_layer,
                                        key_layer,
                                        value_layer,
                                        attention_mask=attention_mask,
                                        qkv_layout=qkv_layout,
                                        cu_seqlens_q=cu_seqlens_q,
                                        cu_seqlens_kv=cu_seqlens_kv,
                                        attn_mask_type=attn_mask_type,
3864
                                        window_size=window_size,
3865
                                        alibi_slopes=alibi_slopes,
3866
3867
                                        cp_group=self.cp_group,
                                        cp_global_ranks=self.cp_global_ranks,
3868
3869
3870
                                        cp_stream=self.cp_stream,
                                        max_seqlen_q=max_seqlen_q,
                                        max_seqlen_kv=max_seqlen_kv)
3871

3872
        if use_fused_attention:
3873
3874
3875
            if _NVTE_DEBUG:
                print("[DotProductAttention]: using cuDNN fused attention (backend "
                    + str(int(fused_attention_backend)) + ")")
3876
            if checkpoint_core_attention:
3877
3878
3879
3880
3881
3882
3883
3884
                return self._checkpointed_attention_forward(
                    self.fused_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
3885
3886
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
3887
3888
3889
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    fused_attention_backend=fused_attention_backend,
3890
3891
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
3892
3893
3894
3895
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
3896
                    is_first_microbatch=is_first_microbatch)
3897
3898
3899
3900
3901
3902
3903
            return self.fused_attention(
                query_layer,
                key_layer,
                value_layer,
                qkv_layout=qkv_layout,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
3904
3905
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
3906
3907
3908
                attn_mask_type=attn_mask_type,
                attention_mask=attention_mask,
                fused_attention_backend=fused_attention_backend,
3909
3910
                core_attention_bias_type=fu_core_attention_bias_type,
                core_attention_bias=fu_core_attention_bias,
3911
3912
3913
3914
                fast_zero_fill=fast_zero_fill,
                cp_group=self.cp_group,
                cp_global_ranks=self.cp_global_ranks,
                cp_stream=self.cp_stream,
3915
                is_first_microbatch=is_first_microbatch)
3916
3917
3918

        assert (not context_parallel), \
            "Context parallelism is only implemented with Flash Attention and Fused Attention!"
3919

3920
3921
3922
3923
3924
3925
3926
        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            warnings.warn(
                           "Attention activation Offloading is only implemented"
                           "with Flash Attention and Fused Attention!"
                         )

3927
3928
        if _NVTE_DEBUG:
            print("[DotProductAttention]: using unfused DPA")
3929
3930
3931
3932
3933
3934
3935
3936
3937
3938
3939
3940
3941
        if use_unfused_attention:
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.unfused_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout = qkv_layout,
                    cu_seqlens_q = cu_seqlens_q,
                    cu_seqlens_kv = cu_seqlens_kv,
                    attn_mask_type = attn_mask_type,
                    attention_mask = attention_mask,
                    core_attention_bias_type = core_attention_bias_type,
3942
3943
                    core_attention_bias = core_attention_bias,
                    alibi_slopes = alibi_slopes)
3944
3945
3946
3947
3948
3949
3950
3951
3952
            return self.unfused_attention(query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout = qkv_layout,
                    cu_seqlens_q = cu_seqlens_q,
                    cu_seqlens_kv = cu_seqlens_kv,
                    attn_mask_type = attn_mask_type,
                    attention_mask = attention_mask,
                    core_attention_bias_type = core_attention_bias_type,
3953
3954
                    core_attention_bias = core_attention_bias,
                    alibi_slopes = alibi_slopes)
3955
3956

        raise Exception("No dot product attention support for the provided inputs!")
3957
3958


3959
3960
3961
3962
3963
3964
3965
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

3966
3967
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.
3968

3969
3970
3971
3972
3973
3974
3975
3976
3977
3978
3979
3980
3981
3982
3983
3984
3985
3986
3987
3988
3989
3990
3991
3992
3993
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
3994
3995
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                   default = `causal`
3996
3997
3998
3999
4000
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
4001
4002
4003
4004
4005
4006
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
4007
4008
4009
4010
4011
4012
4013
4014
4015
4016
4017
4018
4019
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
4020
4021
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
4022
4023
4024
4025
4026
4027
4028
4029
4030
4031
4032
4033
4034
4035
4036
4037
4038
4039
4040
4041
4042
4043
4044
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the
          user's responsibility to ensure all parameters are moved to the GPU before
          running the forward pass.
4045
4046
4047
4048
4049
4050
4051
4052
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
            For that, please use `_get_qkv_layout` to gain the layout information.
4053
4054
4055
4056
4057
4058
4059
4060
4061
4062
4063
4064
4065
4066
4067
4068
4069
4070
4071
4072
4073
4074
4075
4076
4077
4078
4079
4080
4081
4082
4083
4084
4085
4086
4087
4088
4089
4090
4091
4092

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
4093
4094
4095
4096
4097
4098
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """
        Build the query/key/value projection(s), the core attention module and
        the output projection. The arguments are described in the class docstring.

        Depending on :attr:`attention_type` and :attr:`input_layernorm`, the
        input projection is registered as `layernorm_qkv` or `qkv` (self
        attention), or as `layernorm_query` or `query_layer` plus `key_value`
        (cross attention). The core attention is registered as `core_attention`
        and the output projection as `proj`.
        """
        super().__init__()

        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        # Per-head channel count falls back to an even split of the hidden size.
        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaving is forced off whenever the QKV parameters are not fused.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # A supplied process group takes precedence over the `tp_size` argument.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.hidden_size_per_attention_head = kv_channels
        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads)

        # Keyword arguments shared by every projection created below. Note that
        # the caller's `sequence_parallel` flag is forwarded unchanged here,
        # not the `tp_size`-gated value stored on `self.sequence_parallel`.
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # One projection produces query, key and value together.
            qkv_out_features = hidden_size + 2 * self.hidden_size_kv

            # Unfused parameters are split into named query/key/value pieces.
            parameters_split = None
            if not fuse_qkv_params:
                parameters_split = collections.OrderedDict([
                    ("query", hidden_size),
                    ("key", self.hidden_size_kv),
                    ("value", self.hidden_size_kv),
                ])

            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    qkv_out_features,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    qkv_out_features,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # Query gets its own projection; key and value share a second one.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
                    ub_overlap_ag=ub_overlap_ag,
                    normalization=normalization,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            kv_channels,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Linear
        self.proj = Linear(
            hidden_size,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )


    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Allocate an uninitialized buffer on the current CUDA device.

        The returned tensor has shape `[inference_max_sequence_len, batch_size,
        num_gqa_groups_per_partition, hidden_size_per_attention_head]`. The
        `forward` method uses it for the per-layer inference key and value memory.

        Parameters
        ----------
        inference_max_sequence_len : int
                                    size of the first (sequence) dimension.
        batch_size : int
                    size of the second (batch) dimension.
        dtype : torch.dtype
               data type of the allocated tensor.
        """
        buffer_shape = (
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        # torch.empty does not initialize the memory; contents are arbitrary
        # until written.
        return torch.empty(
            buffer_shape,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Only the `tp_group` attribute of this module is updated here.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group

4326
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        The call is forwarded to every descendant module that defines a
        `set_context_parallel_group` method; this module itself stores nothing.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # `modules()` walks the whole subtree and also yields this module;
        # skip it, otherwise the method would call itself without end.
        for child in self.modules():
            if child is self:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
4351

4352
4353
4354
    def forward(
        self,
        hidden_states: torch.Tensor,
4355
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
4356
        encoder_output: Optional[torch.Tensor] = None,
4357
        attn_mask_type: Optional[str] = None,
4358
        window_size: Optional[Tuple[int, int]] = None,
4359
4360
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
4361
        inference_params: Optional[InferenceParams] = None,
4362
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
4363
4364
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
4365
        alibi_slopes: Optional[torch.Tensor] = None,
4366
        fast_zero_fill: bool = True,
4367
    ) -> Tuple[Union[torch.Tensor, None], ...]:
4368
4369
4370
4371
4372
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

4373
4374
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
4375
4376
4377
4378
4379

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
4380
4381
4382
4383
4384
4385
4386
4387
4388
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                       default = `None`
4389
                       type of attention mask passed into softmax operation.
4390
4391
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
4392
4393
4394
4395
4396
4397
4398
4399
4400
4401
4402
4403
4404
4405
4406
4407
4408
4409
4410
4411
4412
4413
4414
4415
4416
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
4417
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
4418
        core_attention_bias: Optional[torch.Tensor], default = `None`
4419
4420
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
4421
4422
4423
4424
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
4425
4426
4427
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
4428
4429
        # hidden_states: [sq, b, h]

4430
4431
        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)
4432
        if attn_mask_type is None:
4433
            attn_mask_type = self.attn_mask_type
4434
4435
        if window_size is None:
            window_size = self.window_size
4436

4437
4438
4439
4440
4441
        if "padding" in attn_mask_type and attention_mask is not None:
            for i,_ in enumerate(attention_mask):
                assert (
                    attention_mask[i].dtype == torch.bool
                ), "Attention mask must be in boolean type!"
4442

4443
4444
        assert (core_attention_bias_type in AttnBiasTypes
                ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
4445

4446
        # =================================================
4447
        # Pre-allocate memory for key-values for inference
4448
4449
4450
4451
        # =================================================

        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
4452
                inf_max_seq_len = inference_params.max_sequence_length
4453
4454
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
4455
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
4456
4457
                )
                inference_value_memory = self._allocate_memory(
4458
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
4459
4460
4461
4462
4463
4464
4465
4466
4467
4468
4469
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

4470
        # ======================
4471
        # Query, Key, and Value
4472
        # ======================
4473

cyanguwa's avatar
cyanguwa committed
4474
4475
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
4476
4477
4478
4479
4480
4481
4482
4483
4484
4485
4486
4487
4488
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
4489
                    is_first_module_in_mha=True, # specific to FP8 MHA
4490
4491
                )

cyanguwa's avatar
cyanguwa committed
4492
4493
            num_queries_per_key_value = (self.num_attention_heads_per_partition //
                                         self.num_gqa_groups_per_partition)
4494
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
4495
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
4496
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
4497
4498
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
4499
4500
4501
4502
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
4503
4504
4505
4506
4507
4508
4509
4510
4511
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head
                )
                # split along third last dimension
                split_dim = -3
4512
4513
4514

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
4515
4516
4517
4518
4519
4520
4521
4522
4523
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
4524
                )
4525
            else:
cyanguwa's avatar
cyanguwa committed
4526
4527
4528
4529
4530
4531
4532
4533
4534
4535
4536
4537
                query_layer, key_layer, value_layer = torch.split(
                    mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim,
                 )

            # query: -> [sq, b, np, hn]
            # key, value: -> [sq, b, ng, hn]
            query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1,
                                                             self.hidden_size_per_attention_head)
                                                   for x in (query_layer, key_layer, value_layer))

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
4538
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
4539
                encoder_output,
4540
                is_first_microbatch=is_first_microbatch,
4541
                is_first_module_in_mha=True, # specific to FP8 MHA
4542
4543
4544
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
4545
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
4546
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
4547
                    self.num_gqa_groups_per_partition,
4548
4549
4550
4551
4552
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
4553
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
4554
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
4555
                    2 * self.num_gqa_groups_per_partition,
4556
4557
4558
4559
4560
4561
4562
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
4563
4564
4565
4566
4567
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2,
                )
4568
            else:
cyanguwa's avatar
cyanguwa committed
4569
4570
4571
                key_layer, value_layer = torch.split(
                    mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
                )
4572
4573
4574
            key_layer, value_layer = (x.reshape(
                x.size(0), x.size(1), -1, self.hidden_size_per_attention_head,
                ) for x in (key_layer, value_layer))
4575
4576
4577
4578
4579
4580
4581
4582
4583
4584
4585
4586
4587
4588
4589

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
4590
                    is_first_module_in_mha=True, # specific to FP8 MHA
4591
4592
4593
4594
4595
4596
4597
4598
4599
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

4600
4601
4602
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
4603

4604
        if rotary_pos_emb is not None:
4605
4606
4607
            assert (not isinstance(query_layer, Float8Tensor)
                and not isinstance(key_layer, Float8Tensor)
                ), "RoPE is not supported for Float8Tensors!"
4608
            # duplicate the pos_emb for self attention
4609
4610
4611
4612
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

            q_pos_emb, k_pos_emb = rotary_pos_emb
4613
4614
4615
4616
4617
4618
4619
4620
4621
4622
4623
4624
4625
4626

            # slice the query/key rotary embeddings to the current sequence window for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

4627
4628
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
4629

4630
4631
4632
4633
        # ===========================
        # Core attention computation
        # ===========================

4634
4635
4636
4637
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
4638
            qkv_format=self.qkv_format,
4639
4640
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
4641
4642
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
4643
            window_size=window_size,
4644
4645
4646
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
4647
            alibi_slopes=alibi_slopes,
4648
            fast_zero_fill=fast_zero_fill,
4649
            inference_params=inference_params,
4650
4651
        )

4652
        # ===================
4653
        # Output. [sq, b, h]
4654
        # ===================
4655

4656
        projection_output = self.proj(
4657
4658
            context_layer,
            is_first_microbatch=is_first_microbatch,
4659
4660
        )

4661
4662
4663
4664
4665
4666
4667
4668
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
4669
        if self.input_layernorm and self.return_layernorm_output:
4670
4671
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]