"cacheflow/vscode:/vscode.git/clone" did not exist on "84eee24e20ff4c0fc1b126289265f560089efa47"
attention.py 169 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
8
9
10
from contextlib import nullcontext
import functools
from importlib.metadata import version
import math
11
import os
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
import warnings
14

cyanguwa's avatar
cyanguwa committed
15
import numpy as np
16
from pkg_resources import packaging
17
18

import torch
19
import torch.nn.functional as F
20
21

import transformer_engine_extensions as tex
22
23
24
25
26
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
27
28
    fused_attn_fwd,
    fused_attn_bwd,
29
30
31
32
33
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
34
35
36
37
38
39
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
40
    get_default_init_method,
41
42
43
44
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
45
    AttnBiasTypes,
46
    QKVLayouts,
47
    dist_group_type,
48
    TE_DType,
49
50
51
52
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
53
    get_distributed_rank,
54
55
56
    checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
57
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
58
59

_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")
_flash_attn_max_version = packaging.version.Version("2.5.6")
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1")

# The flash-attn entry points are only imported when the installed flash-attn meets
# the minimum version; with an older flash-attn these names stay undefined here.
if _flash_attn_version >= _flash_attn_version_required:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module


# Debug level read from the NVTE_DEBUG environment variable (0 when unset).
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))

# Module-level state for `get_alibi`:
#   _alibi_slopes / _alibi_bias: the cached tensors returned by `get_alibi`.
#   _num_heads / _max_seqlen_q / _max_seqlen_kv: the sizes the cache was last built for
#     (written by `get_alibi`; presumably compared by callers elsewhere — not visible here).
#   _alibi_slopes_require_update / _alibi_bias_require_update: when True, the next
#     `get_alibi` call rebuilds the slopes / bias and resets the flag to False.
_alibi_cache = {
    "_num_heads": None,
    "_alibi_slopes": None,
    "_max_seqlen_q": None,
    "_max_seqlen_kv": None,
    "_alibi_bias": None,
    "_alibi_slopes_require_update": False,
    "_alibi_bias_require_update": False,
}

__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]

class InferenceParams: # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # layer_number -> (key_memory, value_memory); the batch dimension of each
        # memory tensor is dim 1 (see `swap_key_value_dict`).
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        AssertionError
            If `len(batch_indices)` differs from the batch dimension (dim 1) of any
            cached key tensor. No layer is modified in that case.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict is empty")

        # Validate every layer before reordering any of them, so a mismatch cannot
        # leave the cache partially swapped.
        for inference_key_memory, _ in self.key_value_memory_dict.values():
            assert (
                len(batch_indices) == inference_key_memory.shape[1]
            ), "batch_indices length must match the KV cache batch size"

        # Reassigning existing keys while iterating is safe: the dict size is unchanged.
        for layer_number, (inference_key_memory, inference_value_memory) in (
            self.key_value_memory_dict.items()
        ):
            self.key_value_memory_dict[layer_number] = (
                inference_key_memory[:, batch_indices],
                inference_value_memory[:, batch_indices],
            )
132

133
134
135
136
137
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the ALiBi slopes and bias held in the module-level `_alibi_cache`,
    rebuilding them only when the cache's update flags request it.

    The slopes are rebuilt when `_alibi_cache["_alibi_slopes_require_update"]` is
    True. `alibi_slopes` is stored as-is if given; otherwise the standard geometric
    slopes for `num_heads` heads are generated. The bias is rebuilt when
    `_alibi_cache["_alibi_bias_require_update"]` is True. When a flag is False, the
    cached tensor is returned unchanged and the corresponding arguments are ignored.
    Both flags are reset to False after a rebuild.

    Parameters
    ----------
    num_heads: int
        Number of heads.
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    alibi_slopes: Optional[torch.Tensor], default = `None`
        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
        Dtype of the generated ALiBi bias. If None, use torch.float32.

    Returns
    ----------
    alibi_slopes: torch.Tensor
        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
        ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
        then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
        `alibi_slopes` is in [batch_size, num_heads], then the bias is in
        [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].

    Raises
    ----------
    ValueError
        If the cached slopes used to build the bias are neither 1D nor 2D.
    """
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # For the largest power of two n <= num_heads the slopes are
            # 2^(-8/n), 2^(-16/n), ..., 2^(-8); any remaining heads take the odd
            # powers 1, 3, 5, ... of 2^(-8/(2n)).
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))

            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])

            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        slopes = _alibi_cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        if slopes.dim() == 1:
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            slopes_shape = torch.Size([*slopes.shape[:], 1, 1])
        else:
            raise ValueError(
                "ALiBi slopes must have shape [num_heads] or [batch_size, num_heads], "
                f"got {tuple(slopes.shape)}"
            )
        # Key and query positions both run 1-max_seqlen .. 0, i.e. the two sequences
        # are aligned at their last token; bias[..., i, j] = -|kv_pos[j] - q_pos[i]| * slope.
        bias = torch.arange(
            1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
        bias = bias - torch.arange(
            1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(1, 1, max_seqlen_q, 1)
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216


def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    Each sample's length is the number of non-zero (True) entries in its row of
    `mask`; the result starts with 0 and lives on `mask.device`.
    """
    mask = mask.squeeze(1).squeeze(1)
    seqlens = mask.sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Build the leading zero on the mask's own device (previously hard-coded to
    # "cuda", which broke CPU masks and masks on a non-default GPU).
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    return torch.cat((zero, cu_seqlens))

217

218
219
220
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flattened (row-major) positions of the valid tokens.

    Valid tokens are the non-zero (True) entries of `mask`. The index tensor is padded
    at the end with the value `batch_size * max_seqlen` (one past the last position).
    Both outputs live on `mask.device`.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    reduced_mask = mask.sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Leading zero on the mask's device (previously hard-coded to "cuda").
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() returns int64 indices of shape [num_valid, 1] -> [num_valid, 1, 1]
    indices = mask.reshape(-1).nonzero().unsqueeze(-1)

    num_nonzeros = indices.shape[0]
    pad_amount = bs * seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * seqlen))

    return cu_seqlens, indices


245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def get_indices(
    max_seqlen: int,
    cu_seqlens: torch.Tensor,
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the flattened
    positions `i * max_seqlen + t` of the valid tokens in a batch.

    The result is padded at the end with the value `batch_size * max_seqlen`.

    Parameters
    ----------
    max_seqlen: int
        Padded length of every sample.
    cu_seqlens: torch.Tensor
        Cumulative sequence lengths, starting with 0.
    device: Optional[Union[str, torch.device]], default = `None`
        Device of the returned tensor. `None` keeps the historical default, "cuda".
    """
    if device is None:
        device = "cuda"
    bs = len(cu_seqlens) - 1
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64)

    # Vectorized replacement for a per-token Python loop: the old loop forced a host
    # sync per element and went through a float32 tensor, which cannot represent
    # indices above 2**24 exactly. nonzero() on the row-major [bs, max_seqlen] mask
    # yields the same ordering (sample-major, then token).
    positions = torch.arange(max_seqlen, dtype=torch.int64, device=cu_seqlens.device)
    valid = positions.unsqueeze(0) < seqlens.unsqueeze(1)
    indices = valid.reshape(-1).nonzero().unsqueeze(-1).to(device=device)

    num_nonzeros = indices.shape[0]
    pad_amount = bs * max_seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * max_seqlen))

    return indices


265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
@functools.lru_cache
def _get_full_cu_seqlens(
    batch_size: int,
    max_seqlen: int,
    device: torch.device,
) -> torch.Tensor:
    """Cumulative sequence lengths in full data batch

    All sequences in batch have the maximum sequence length.

    """
    return torch.arange(
        0,
        (batch_size + 1) * max_seqlen,
        step=max_seqlen,
        dtype=torch.int32,
        device=device,
    )


285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Gather the rows (dim 0) of `tensor` selected by `indices`.

    One all-zero row is appended to `tensor` before gathering, so an index equal to
    `tensor.shape[0]` (the padding value) gathers zeros. `indices` is repeated over
    dims 1 and 2; presumably it has shape [N, 1, 1] — confirm against callers.
    """
    zero_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    padded = torch.cat((tensor, zero_row), dim=0)
    expanded_indices = indices.repeat(1, padded.shape[1], padded.shape[2])
    return torch.gather(padded, 0, expanded_indices)


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply `pack_tensor` to `t1` and `t2` with the same `indices`.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply `pack_tensor` to `t1`, `t2` and `t3` with the same `indices`.
    """
    return (
        pack_tensor(indices, t1),
        pack_tensor(indices, t2),
        pack_tensor(indices, t3),
    )


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`: scatter the rows of `tensor` back to positions `indices`
    in a zero-filled tensor with `dim0` rows.

    Scattering targets a buffer with one extra row, so rows whose index equals `dim0`
    (the padding value) land in that extra row, which is then dropped.
    """
    expanded_indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    buffer = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    buffer.scatter_(0, expanded_indices, tensor)
    return buffer[:-1, :, :]


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`: apply `unpack_tensor` to `t1` and `t2`.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`: apply `unpack_tensor` to `t1`, `t2` and `t3`.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward gathers 1 to 3 tensors with shared `indices` (via `pack_tensor`,
    `pack_2_tensors` or `pack_3_tensors`); backward scatters the incoming gradients
    back to the original row positions with the matching `unpack_*` helper.
    `indices` receives no gradient.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        *tensors: torch.Tensor
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        # Tensors reused in backward go through save_for_backward instead of being
        # stashed on ctx, as required by the torch.autograd.Function contract.
        ctx.save_for_backward(indices)
        ctx.dim0 = tensors[0].shape[0]
        if len(tensors) == 1:
            return pack_tensor(indices, *tensors)
        if len(tensors) == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_3_tensors(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: torch.Tensor):
        (indices,) = ctx.saved_tensors
        if len(grad_outputs) == 1:
            return None, unpack_tensor(indices, ctx.dim0, *grad_outputs)
        if len(grad_outputs) == 2:
            return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs)
        return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs)


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    Forward scatters `tensor` back into a `dim0`-row tensor with `unpack_tensor`;
    backward gathers the gradient with `pack_tensor` using the same `indices`.
    Neither `indices` nor `dim0` receives a gradient.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        # Tensors reused in backward go through save_for_backward instead of being
        # stashed on ctx, as required by the torch.autograd.Function contract.
        ctx.save_for_backward(indices)
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        return None, None, pack_tensor(indices, grad_output)


428
429
430
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                               recv_tensor, recv_src,
                               cp_group, batch_p2p_comm):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Posts a non-blocking send of `send_tensor` to `send_dst` and a non-blocking
    receive into `recv_tensor` from `recv_src` on `cp_group`. Even `rank`s post the
    send first and odd ranks the receive first (presumably so neighbouring ranks'
    orders pair up without deadlock). When `batch_p2p_comm` is truthy, both ops are
    issued through a single `torch.distributed.batch_isend_irecv` call.

    Returns the list of request handles; each must be waited on before
    `recv_tensor` is read.
    """
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        send_op = torch.distributed.P2POp(torch.distributed.isend,
                                          send_tensor,
                                          send_dst,
                                          cp_group)
        recv_op = torch.distributed.P2POp(torch.distributed.irecv,
                                          recv_tensor,
                                          recv_src,
                                          cp_group)
        # Only the order in the list matters; constructing a P2POp issues nothing.
        send_recv_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return torch.distributed.batch_isend_irecv(send_recv_ops)

    if send_first:
        send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
        recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]

    recv_req = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
    send_req = torch.distributed.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]


474
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism

    Accumulates in place:
    `out += out_per_step * exp(softmax_lse_per_step - softmax_lse)`, with the LSE
    difference transposed on dims 1 and 2 and given a trailing singleton dimension
    so it broadcasts against `out`. Assumes LSEs are [b, np, sq] and outputs are
    [b, sq, np, hn], as AttnFuncWithCP passes them — confirm for other callers.
    """
    # Weight of this step's partial output relative to the merged normalizer.
    step_weight = torch.exp(softmax_lse_per_step - softmax_lse).transpose(1, 2).unsqueeze(-1)
    out.add_(out_per_step * step_weight)


483
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism

    Updates `softmax_lse` in place to `log(exp(softmax_lse) + exp(softmax_lse_per_step))`,
    with `softmax_lse_per_step` first cast to float64. `torch.logaddexp` evaluates
    this without the overflow of exponentiating large log-sum-exp values directly.
    The result is copied back with `copy_`, so a view passed in as `softmax_lse`
    still updates its base tensor.
    """
    merged = torch.logaddexp(softmax_lse, softmax_lse_per_step.to(torch.double))
    softmax_lse.copy_(merged)


491
class AttnFuncWithCP(torch.autograd.Function):
492
    """
493
494
    Attention implementation with context parallelism.
    Split attention compute into multiple steps, and overlap current-step
495
496
497
498
    compute with next-step communication.
    """

    @staticmethod
499
500
501
    def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
                deterministic, use_fused_attention):
502
503
504
505
506
507
508
509
510
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
        recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

511
512
513
514
515
        causal = (attn_mask_type == "causal")

        if causal:
            # [b, s, np, hn] -> [b, 2, s//2, np, hn]
            q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
516
        assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
517
518
519
520
521
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None
522

523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        p2p_comm_buffers = [None for _ in range(cp_size)]
        p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
        send_recv_reqs = [[], []]

        for i in range(cp_size+1):
            if i < cp_size:
                with torch.cuda.stream(flash_attn_streams[i%2]):
                    # wait until KV is received
                    for req in send_recv_reqs[(i+1)%2]:
                        req.wait()

                    if i < (cp_size-1):
                        p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
                                                                         p2p_comm_buffers[i],
                                                                         send_dst,
                                                                         p2p_comm_buffers[i+1],
                                                                         recv_src,
                                                                         cp_group,
                                                                         batch_p2p_comm)

                    kv_inputs[i%2] = p2p_comm_buffers[i]
                    if causal:
                        if i == 0:
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
                            if use_fused_attention:
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(
                                    2, k.shape[0], -1, *k.shape[-2:])
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
                                    qkv_layout="bshd_bshd_bshd", attn_mask_type="causal",
                                )
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                                q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=True, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
587
                        elif i <= rank:
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
                            if use_fused_attention:
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q,
                                    cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
                                    qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
                                )
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                                q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                        else:
                            if use_fused_attention:
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                q_inputs[i%2] = q[:, 1, ...].contiguous()
                                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(
                                    2, k.shape[0], -1, *k.shape[-2:])
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
                                fused_attn_fwd(
                                    is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
                                    qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
                                )
                            else:
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                                q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                    else:
                        if use_fused_attention:
                            out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
                            fused_attn_fwd(
                                is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                cu_seqlens_k, q, kv_inputs[i%2][0],
                                kv_inputs[i%2][1], TE_DType[q.dtype],
                                tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                attn_scale=softmax_scale, dropout=dropout_p,
                                qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
658
                            )
659
                        else:
660
661
662
                            # [b, sq, np, hn] -> [b*sq, np, hn]
                            q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
663
                            kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
664
665
666
                            _, _, _, _, out_per_step[i], \
                            softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
667
668
669
                                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal=False, return_softmax=False,
                                **fa_optional_forward_kwargs
670
                            )
671
672
673
674
675
676

            if i > 0:
                # wait until fwd restuls correction of last step is done
                if i > 1:
                    flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)

677
678
679
680
                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
                    softmax_lse_per_step[i-1].squeeze_(-1)

681
                with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
682
683
684
685
                    if i == 1:
                        out = torch.empty_like(q).zero_()
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
                        if causal:
686
687
688
689
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
                            )
690
691
692
                    elif (i-1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(softmax_lse,
                                                              softmax_lse_per_step[i-1])
693
                    else:
694
695
                        flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
                                                              softmax_lse_per_step[i-1])
696
697
698
699
700
701
702
703
704
705

                if i < cp_size:
                    flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
            # [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn]
            out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
706
            if i <= rank or not causal:
707
708
709
710
711
712
713
714
715
716
717
                flash_attn_fwd_out_correction(out.view(*out_.shape),
                                              out_,
                                              softmax_lse,
                                              softmax_lse_per_step[i])
            else:
                flash_attn_fwd_out_correction(out[:, 1, ...],
                                              out_,
                                              softmax_lse_[..., 1, :],
                                              softmax_lse_per_step[i])

        kv = p2p_comm_buffers[-1]
718
719
720
721
        if use_fused_attention:
            out = out.view(out.shape[0], -1, *out.shape[-2:])
        else:
            out = out.view(-1, *out.shape[-2:])
722
723
724
725
726
727
728
729
730
731
        ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
        ctx.rng_states = rng_states
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
732
        ctx.use_fused_attention = use_fused_attention
733
734
735
736
737
738
739
740
741
742
743
744
        return out

    @staticmethod
    def backward(ctx, dout):
        """Backward pass of ring-exchange context-parallel attention.

        Runs ``cp_size`` steps. At each step this rank waits for the previous
        exchange to finish, posts the next send/receive with its ring neighbours
        (sends to rank ``rank - 1`` and receives from rank ``rank + 1``, wrapping
        around), then computes this step's partial dQ/dK/dV with either TE fused
        attention or flash-attention. Steps run in reverse order of the forward
        pass, so step ``i`` reuses the RNG state of forward step ``cp_size - i - 1``.
        dQ is accumulated locally. dK/dV partial sums are accumulated into slot
        ``[1]`` of the two alternating communication buffers, which are forwarded
        along the ring together with K/V.

        Parameters
        ----------
        ctx:
            Autograd context populated by ``forward``.
        dout: torch.Tensor
            Gradient of the attention output; viewed to ``q``'s shape below.

        Returns
        -------
        tuple
            One entry per forward input: ``dq``, ``dk``, ``dv`` for ``q``, ``k``,
            ``v`` (positions 1-3) and ``None`` for every other input.
        """
        q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
        # Ring neighbours in backward: send to the previous rank, receive from the next one.
        send_dst = ctx.cp_global_ranks[(rank + cp_size - 1) % cp_size]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

        if ctx.causal:
            # [b, np, sq] -> [b, np, 2, sq//2]
            softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
            # LSE of the second half of the local queries only.
            softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
            if ctx.use_fused_attention:
                # [b, np, sq//2] -> [b, np, sq//2, 1]
                softmax_lse_ = softmax_lse_.unsqueeze(-1)
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            # Out-of-place on purpose: do not reshape the tensor handed back by
            # ctx.saved_tensors in place (it would be seen reshaped by any repeated
            # backward, e.g. with retain_graph=True).
            softmax_lse = softmax_lse.unsqueeze(-1)

        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        # Flash Attn outputs
        dq = torch.empty_like(q)

        # Two alternating buffers; slot [0] carries K/V, slot [1] carries dK/dV.
        p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
        p2p_comm_buffers[0][0].copy_(kv)
        send_recv_reqs = []

        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

            send_tensor = p2p_comm_buffers[i%2]
            recv_tensor = p2p_comm_buffers[(i+1)%2]
            # First step exchanges only K/V, last step only dK/dV, others both.
            if i == 0:
                send_tensor = send_tensor[0]
                recv_tensor = recv_tensor[0]
            if i == (cp_size-1):
                send_tensor = send_tensor[1]
                recv_tensor = recv_tensor[1]

            send_recv_reqs = flash_attn_p2p_communicate(rank,
                                                        send_tensor,
                                                        send_dst,
                                                        recv_tensor,
                                                        recv_src,
                                                        ctx.cp_group,
                                                        batch_p2p_comm)

            kv = p2p_comm_buffers[i%2][0]
            # In reversed order of fwd
            if ctx.causal:
                if i == (cp_size-1):
                    # Corresponds to forward step 0: full q/kv with a causal mask.
                    if ctx.use_fused_attention:
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                        kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                        dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        dq_, dk_, dv_, _ = fused_attn_bwd(
                            ctx.max_seqlen_q, ctx.max_seqlen_k,
                            cu_seqlens_q, cu_seqlens_k,
                            q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
                            [softmax_lse, ctx.rng_states[cp_size-i-1]],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout="bshd_bshd_bshd",
                            attn_mask_type="causal",
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, 0]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                            ctx.max_seqlen_q, ctx.max_seqlen_k,
                            ctx.dropout_p, ctx.softmax_scale, True,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
                elif i >= (cp_size-rank-1):
                    # All local queries against the first half of the received K/V.
                    if ctx.use_fused_attention:
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                        kv_ = kv[:, :, 0, ...].contiguous()
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                        dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        dq_, dk_, dv_, _ = fused_attn_bwd(
                            ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                            cu_seqlens_q, cu_seqlens_k//2,
                            q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
                            [softmax_lse, ctx.rng_states[cp_size-i-1]],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout="bshd_bshd_bshd",
                            attn_mask_type="no_mask",
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                        kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
                            ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                            ctx.dropout_p, ctx.softmax_scale, False,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
                else:
                    # Second half of the local queries against the full received K/V.
                    if ctx.use_fused_attention:
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_ = q[:, 1, ...].contiguous()
                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                        kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        out_ = out[:, 1, ...].contiguous()
                        dout_ = dout[:, 1, ...].contiguous()
                        dq_, dk_, dv_, _ = fused_attn_bwd(
                            ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                            cu_seqlens_q//2, cu_seqlens_k,
                            q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
                            [softmax_lse_, ctx.rng_states[cp_size-i-1]],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout="bshd_bshd_bshd",
                            attn_mask_type="no_mask",
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                        q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                        out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                        dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
                            ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                            ctx.dropout_p, ctx.softmax_scale, False,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
            else:
                if ctx.use_fused_attention:
                    dq_, dk_, dv_, _ = fused_attn_bwd(
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        cu_seqlens_q, cu_seqlens_k,
                        q, kv[0], kv[1], out, dout, TE_DType[q.dtype],
                        [softmax_lse, ctx.rng_states[cp_size-i-1]],
                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
                        qkv_layout="bshd_bshd_bshd",
                        attn_mask_type="no_mask",
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
                    q_ = q.view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
                    # [b, sq, np, hn] -> [b*sq, np, hn]
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
                    if _flash_attn_2_3_plus:
                        fa_optional_backward_kwargs["window_size"] = [-1, -1]
                    # Pass the forward step's RNG state (as every other branch does) so
                    # dropout in backward uses the same state as the matching forward step.
                    _flash_attn_backward(
                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                        dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        ctx.dropout_p, ctx.softmax_scale, False,
                        rng_state=ctx.rng_states[cp_size-i-1],
                        **fa_optional_backward_kwargs
                    )

            if i >= (cp_size-rank-1) or not ctx.causal:
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
                # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                dq_ = dq_.view(dq.shape[0], *dq.shape[2:])

            # Accumulate this step's dQ; first writes use copy_ since dq starts uninitialized.
            if ctx.causal:
                if i > (cp_size-rank-1):
                    dq.add_(dq_)
                elif i == (cp_size-rank-1):
                    if rank == (cp_size-1):
                        dq.copy_(dq_)
                    else:
                        dq[:, 0, ...].copy_(dq_[:, 0, ...])
                        dq[:, 1, ...].add_(dq_[:, 1, ...])
                elif i > 0:
                    dq[:, 1, ...].add_(dq_)
                else:
                    dq[:, 1, ...].copy_(dq_)
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)

            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()

            dkv = p2p_comm_buffers[(i+1)%2][1]
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
            if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1):
                # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)

            # Accumulate this step's dK/dV into the buffer that travels along the ring.
            if ctx.causal:
                if i == (cp_size-1):
                    if rank == 0:
                        dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                        dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                    else:
                        dkv.add_(dkv_)
                elif i >= (cp_size-rank-1):
                    if i == 0 and rank == (cp_size-1):
                        dkv[:, :, 0, ...].copy_(dkv_)
                    else:
                        dkv[:, :, 0, ...].add_(dkv_)
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

        if ctx.causal:
            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
            dq = dq.view(q.shape[0], -1, *q.shape[-2:])
            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
            dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
        return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
                None, None, None, None, None, None


def attn_forward_func_with_cp(
    is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
    cp_group, cp_global_ranks, cp_stream, softmax_scale=None, attn_mask_type="causal",
    deterministic=False, use_fused_attention=False
) -> torch.Tensor:
    """Attention implementation with context parallelism.

    Thin wrapper around ``AttnFuncWithCP.apply``. It validates the mask type and
    forwards every argument positionally, in signature order.

    Parameters
    ----------
    is_training: bool
        Forwarded unchanged to ``AttnFuncWithCP``.
    q, k, v: torch.Tensor
        Local query / key / value chunks of this context-parallel rank.
    cu_seqlens_q, cu_seqlens_k: torch.Tensor
        Cumulative sequence lengths for queries and keys.
    max_seqlen_q, max_seqlen_k: int
        Maximum query / key sequence lengths.
    dropout_p: float
        Attention dropout probability.
    cp_group:
        Process group spanning the context-parallel ranks.
    cp_global_ranks:
        Global ranks of the members of ``cp_group``, indexed by group rank.
    cp_stream:
        Side stream forwarded to ``AttnFuncWithCP``.
    softmax_scale: float, default = None
        Softmax scaling factor, forwarded unchanged.
    attn_mask_type: {'causal', 'no_mask'}, default = 'causal'
        Only these two mask types are supported with context parallelism.
    deterministic: bool, default = False
        Forwarded to the attention backward.
    use_fused_attention: bool, default = False
        Use TE fused attention instead of flash-attention.

    Returns
    -------
    torch.Tensor
        The attention output produced by ``AttnFuncWithCP``.

    Raises
    ------
    AssertionError
        If ``attn_mask_type`` is neither ``"causal"`` nor ``"no_mask"``.
    """
    assert (attn_mask_type in ["causal", "no_mask"]
        ), f"Mask type of {attn_mask_type} is not supported with context parallelism!"
    out = AttnFuncWithCP.apply(
        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
        deterministic, use_fused_attention
    )
    return out


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """
    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int
            if not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595. Interpolation is only applied when
            `pretrained_max_position_embeddings` is also set; otherwise this factor is ignored.
        pretrained_max_position_embeddings: int
            pre-trained max_position_embeddings before position interpolation
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        # inv_freq[k] = 1 / 10000 ** (2k / dim) for 2k < dim, allocated on the current CUDA device.
        inv_freq = 1.0 / (
            10000
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer('inv_freq', inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset for frequencies

        Returns
        -------
        torch.Tensor
            Tensor of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]` in the dtype of
            `inv_freq`, whose row `i` holds `(i + offset) * inv_freq` (after optional
            interpolation) repeated twice along the last dimension.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if (self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None):
            if (max_seq_len >
                self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # Outer product: freqs[i, j] = seq[i] * inv_freq[j].
        freqs = torch.einsum('i , j -> i j', seq, self.inv_freq)
        # Both halves of the last dimension carry the same frequencies.
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))


class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Apply RoPE to `t` with the fused kernels.

        `bshd` input is transposed to `sbhd` order (a view, no copy) before the kernel
        call and the result is transposed back. `thd` input is handed, together with
        `cu_seqlens`, to the dedicated packed-sequence kernel.

        Raises
        ------
        ValueError
            If `tensor_format` is not one of 'sbhd', 'bshd' or 'thd'.
        """
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(
                t.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Gradient of the fused RoPE w.r.t. `t`, dispatched on the saved tensor format.

        Returns one entry per forward input (`t`, `freqs`, `tensor_format`,
        `cu_seqlens`); only `t` receives a gradient.

        Raises
        ------
        ValueError
            If the saved `tensor_format` is not one of 'sbhd', 'bshd' or 'thd'.
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        # Exactly one gradient per forward input: t, freqs, tensor_format, cu_seqlens.
        return grad_input, None, None, None


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Returns
    -------
    torch.Tensor
        The rotated tensor. In the unfused path it has the shape and dtype of `t`:
        the first `d2` channels are rotated and the remaining `d - d2` channels are
        passed through unchanged.

    Raises
    ------
    AssertionError
        If `fused` is True, `tensor_format` is 'thd' and `cu_seqlens` is None; if
        `fused` is False and `tensor_format` is not 'sbhd'/'bshd'; or if the sequence
        length of `t` exceeds `freqs.shape[0]` in the unfused path.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert cur_seq_len <= max_seq_len, (
        f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    )
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)


class _SplitAlongDim(torch.autograd.Function):
1234
1235
1236
1237
1238
    """"""

    @staticmethod
    def forward(ctx,
                mixed_x_layer: torch.Tensor,
cyanguwa's avatar
cyanguwa committed
1239
1240
                split_dim: int,
                split_size_or_sections: Union[int, List[int], Tuple[int]],
1241
    ) -> Tuple[torch.Tensor, ...]:
cyanguwa's avatar
cyanguwa committed
1242
1243
1244
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)
1245
1246
1247
1248
1249
1250

    @staticmethod
    def backward(ctx,
                 *grad_outputs):
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

cyanguwa's avatar
cyanguwa committed
1251
1252
1253
1254
1255
1256
1257
1258
1259
        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert (len(grad_outputs) == len(split_sizes)
                ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

1260
1261
        noop_ok = True
        strides = grad_outputs[0].stride()
1262
        data_ptr = grad_outputs[0].untyped_storage().data_ptr()
cyanguwa's avatar
cyanguwa committed
1263
        shape = list(grad_outputs[0].shape)
1264
        for i, tensor in enumerate(grad_outputs):
cyanguwa's avatar
cyanguwa committed
1265
1266
1267
            shape_i = shape
            shape_i[split_dim] = split_sizes[i]
            offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:])
1268
            if (tensor.stride() != strides or
cyanguwa's avatar
cyanguwa committed
1269
                list(tensor.shape) != shape_i or
1270
                tensor.untyped_storage().data_ptr() != data_ptr or
cyanguwa's avatar
cyanguwa committed
1271
                tensor.storage_offset() != offset_size):
1272
1273
1274
1275
1276
1277
1278
                noop_ok = False
                break

        if noop_ok:
            ret = torch.Tensor().to(device=grad_outputs[0].device,
                                    dtype=grad_outputs[0].dtype)
            new_shape = list(shape)
cyanguwa's avatar
cyanguwa committed
1279
1280
            new_shape[split_dim] = sum(split_sizes)
            ret.set_(grad_outputs[0].untyped_storage(),
1281
1282
                     grad_outputs[0].storage_offset(),
                     new_shape,
cyanguwa's avatar
cyanguwa committed
1283
                     strides
1284
            )
cyanguwa's avatar
cyanguwa committed
1285
            return ret, None, None
1286

cyanguwa's avatar
cyanguwa committed
1287
        return torch.cat(grad_outputs, dim = split_dim), None, None


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Parameters
    ----------
    norm_factor: float
        Q·K^T scores are multiplied by `1 / norm_factor`. When QK layer scaling is
        active, they are multiplied by `1 / (norm_factor * layer_number)` instead, and
        `layer_number` is passed to the softmax as its scale.
    attention_dropout: float, default = 0.0
        Dropout probability applied to the attention probabilities.
    attention_dropout_ctx: Callable, default = `nullcontext`
        Zero-argument factory for the context manager that wraps the dropout call.
    layer_number: int, optional
        Layer index. QK layer scaling can only be enabled when this is set.
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        # Enabled only when NVTE_APPLY_QK_LAYER_SCALING parses to a non-zero int
        # and a layer number is available to use as the extra scale factor.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop.

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            Inputs in the format implied by `qkv_layout`: `sbhd` or `bshd`.
            `thd` is rejected. `bshd` inputs are transposed to `sbhd` internally.
            Key/value may have fewer heads than query (GQA/MQA). They are then
            repeated along the head dimension to match.
        qkv_layout: str, default = `sbh3d`
            Must be one of `QKVLayouts`.
        cu_seqlens_q, cu_seqlens_kv: torch.Tensor, optional
            Unused. Accepted for interface parity with the other backends.
        attn_mask_type: str, default = `causal`
            Forwarded to the fused scale-mask-softmax.
        attention_mask: torch.Tensor or tuple of torch.Tensor, optional
            Forwarded to the fused scale-mask-softmax.
        core_attention_bias_type: str, default = `no_bias`
            One of `no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`.
        core_attention_bias: torch.Tensor, optional
            Bias added (with broadcasting) to the `[b, np, sq, sk]` scores.
            Required for `pre_scale_bias` and `post_scale_bias`.
        alibi_slopes: torch.Tensor, optional
            Passed to `get_alibi` when `core_attention_bias_type = alibi`.

        Returns
        -------
        torch.Tensor
            `[sq, b, np * hn]` for `sbhd` inputs, `[b, sq, np * hn]` for `bshd`.

        Raises
        ------
        ValueError
            If `core_attention_bias_type` is not one of the supported values.
        """
        assert (qkv_layout in QKVLayouts
            ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        # e.g. 'sbh3d' -> 'sbhd', 'bshd_bs2hd' -> 'bshd'
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
        assert (qkv_format != 'thd'
            ), """UnfusedDotProductAttention does not support variable sequence lengths!"""
        if qkv_format == 'bshd':
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [x.transpose(0, 1)
                for x in [query_layer, key_layer, value_layer]]

        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            # GQA/MQA: replicate each KV head so K/V have as many heads as Q.
            assert (query_layer.shape[2] % key_layer.shape[2] == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                    query_layer.shape[2] // key_layer.shape[2], dim=2)
            value_layer = value_layer.repeat_interleave(
                    query_layer.shape[2] // value_layer.shape[2], dim=2)

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(
            output_size[2], output_size[0] * output_size[1], -1
        )
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.norm_factor
        if apply_qk_layer_scaling:
            scale *= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / scale),
            )

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # Bias is added before the 1/scale factor is applied.
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3])
            matmul_result /= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
            if core_attention_bias_type == "alibi":
                _, core_attention_bias = get_alibi(
                    output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes)
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / scale),
            )
            # Bias is added after scaling; cast back in case the bias promoted the dtype.
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3]).to(
                dtype=query_layer.dtype)

        else:
            # Without this, the uninitialized `torch.empty` buffer would reach softmax.
            raise ValueError(
                f"Unsupported core_attention_bias_type = {core_attention_bias_type}!")

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask, attn_mask_type, softmax_scale)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == 'sbhd':
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        if qkv_format == 'bshd':
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """This class converts QKV from interleaved (s, b, ...) layout
       to separate contiguous q, k, v tensors in (b, s, ...) layout."""

    @staticmethod
    def forward(ctx,
                query_layer: torch.Tensor,
                key_layer: torch.Tensor,
                value_layer: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the interleaved QKV buffer behind `query_layer` into q, k, v.

        Only `query_layer` is read. `key_layer` and `value_layer` are accepted so
        that autograd produces one gradient per input.

        Returns
        -------
        Tuple of three tensors (query, key, value).
        """
        # All inputs received are non-contiguous tensors.
        # The `query_layer` tensor is used to access the
        # full memory region of the QKV tensor.
        qkv = tex.fa_prepare_fwd(query_layer)
        # Split into 3 chunks along dim 0, then drop that (presumably size-1) dim.
        q, k, v = split_tensor_along_dim(qkv, 0, 3)
        query_layer = torch.squeeze(q, 0)
        key_layer = torch.squeeze(k, 0)
        value_layer = torch.squeeze(v, 0)
        return query_layer, key_layer, value_layer

    @staticmethod
    def backward(ctx,
                 dq: torch.Tensor,
                 dk: torch.Tensor,
                 dv: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Re-interleave dq, dk, dv and split the result into three gradients.

        The split is along the last dimension.
        """
        dqkv = tex.fa_prepare_bwd(dq, dk, dv)
        dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
        return dq, dk, dv

1513

1514
1515
1516
1517
1518
1519
1520
def _get_qkv_layout(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qkv_format: str = 'sbhd',
    ) -> str:
    """Get qkv layout.
1521

1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of sequences in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    """
1550

1551
1552
    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
1553

1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
    def run_iteratively(q, k, v):
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        stride = k.stride()
        check_strides_kv = all(stride == x.stride() for x in [k, v])

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = all(shape == x.shape for x in [k, v])

        last_dim_size = q.shape[-1]
        check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset()
                            for i, x in enumerate([q, k, v]))
        last_dim_size = k.shape[-1]
        check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset()
                            for i, x in enumerate([k, v]))

        last_two_dims_size = q.shape[-1] * q.shape[-2]
        check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset()
                            for i, x in enumerate([q, k, v]))
        last_two_dims_size = k.shape[-1] * k.shape[-2]
        check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
                            for i, x in enumerate([k, v]))

        if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
            and check_last_two_dims_offsets_qkv
            and not check_last_dim_offsets_qkv):
            # sb3hd, bs3hd, t3hd
            qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:]
        elif (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
            and check_last_dim_offsets_qkv):
            # sbh3d, bsh3d, th3d
            qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:]
        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
            and check_last_two_dims_offsets_kv
            and not check_last_dim_offsets_kv):
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:]
        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
            and check_last_dim_offsets_kv):
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:]
        elif check_strides_kv and check_shapes_kv:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            qkv_layout = '_'.join(list([qkv_format])*3)
        else:
            qkv_layout = 'not_supported'

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == 'not_supported':
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == 'not_supported':
1616
1617
        raise Exception("The provided qkv memory layout is not supported!")

1618
    return qkv_layout, q, k, v
1619

1620

1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
def check_set_window_size(
        attn_mask_type: str,
        window_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[int, int]:
    """Check if sliding window size is compliant with mask type and if not,
    assert or set it to the appropriate size.

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type. Any value containing the substring `causal` is
        treated as causal.
    window_size: Tuple[int, int], optional
        Requested `(left, right)` window. `None` selects the default for the mask
        type: `(-1, 0)` for causal masks, `(-1, -1)` otherwise.

    Returns
    -------
    Tuple[int, int]
        The default window if `window_size` is None, else `window_size` unchanged.

    Raises
    ------
    AssertionError
        If the mask is causal and `window_size[1] != 0`.
    """
    if "causal" in attn_mask_type:
        if window_size is None:
            return (-1, 0)
        # A causal mask cannot look ahead, so the right window must be 0.
        assert (
            window_size[1] == 0
        ), "window_size[1] should be 0 when self_attn_mask_type includes 'causal'!"
        return window_size
    if window_size is None:
        return (-1, -1)
    return window_size
1639

1640

1641
class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Parameters
    ----------
    norm_factor: float
        Scores are scaled by `1 / norm_factor` (passed as `softmax_scale`).
    attention_dropout: float, default = 0.0
        Dropout probability, applied only in training mode.
    attention_dropout_ctx: Callable, default = `nullcontext`
        Zero-argument factory for the context manager wrapping the kernel call.
    attention_type: str, default = `self`
        `self` or cross attention. This affects how padding masks are unpacked.
    layer_number: int, optional
        Stored as `1` when not given.
    deterministic: bool, default = False
        Forwarded to flash-attn (when supported) and to the context-parallel path.
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            _flash_attn_version <= _flash_attn_max_version
        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> torch.Tensor:
        """flash-attn fprop.

        The input format (`sbhd`, `bshd` or `thd`) is derived from `qkv_layout`.
        Without context parallelism, `sbhd`/`bshd` inputs are flattened to
        `[b * s, h, d]` before calling the varlen flash-attn kernel.
        With a padding mask, `sbhd`/`bshd` tokens are packed with `PackTensors`
        before the kernel and scattered back with `UnpackTensor` afterwards.

        Returns
        -------
        torch.Tensor
            `[s, b, h * d]` for `sbhd`, `[b, s, h * d]` for `bshd`,
            `[t, h * d]` for `thd`.
        """
        window_size = check_set_window_size(attn_mask_type, window_size)

        assert (
            query_layer.dtype in [torch.float16, torch.bfloat16]
            and key_layer.dtype in [torch.float16, torch.bfloat16]
            and value_layer.dtype in [torch.float16, torch.bfloat16]
            ), "FlashAttention currently only supports FP16 and BF16."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
            ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # e.g. 'sbh3d' -> 'sbhd', 'thd_t2hd' -> 'thd'
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])

        # Bring sbhd/bshd inputs to contiguous [b, s, h, d].
        if qkv_format == 'sbhd':
            # For now just 128, will make it more general in the future
            if (query_layer.shape[-1] == 128 and
                query_layer.shape[0] * query_layer.shape[1] >= 512 and
                qkv_layout == "sbh3d"):
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
                                                                             key_layer,
                                                                             value_layer)
            else:
                query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
                    for x in (query_layer, key_layer, value_layer)]
        elif qkv_format == 'bshd':
            query_layer, key_layer, value_layer = [x.contiguous()
                for x in (query_layer, key_layer, value_layer)]

        # For thd this is the total token count, which is only used by sbhd/bshd paths.
        batch_size = query_layer.shape[0]

        if qkv_format in ['sbhd', 'bshd']:
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            if not context_parallel:
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.view(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

            if 'padding' in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if cu_seqlens_q is None:
                        assert (attention_mask is not None
                                ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    cu_seqlens_kv = cu_seqlens_q
                    query_layer, key_layer, value_layer = PackTensors.apply(
                        indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    # Cross attention: Q and KV have separate masks / seqlens.
                    if cu_seqlens_q is None or cu_seqlens_kv is None:
                        assert (attention_mask is not None
                            ), "Please provide attention_mask for padding!"
                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(
                            attention_mask[0])
                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(
                            attention_mask[1])
                    else:
                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer = PackTensors.apply(indices_q, query_layer)
                    key_layer, value_layer = PackTensors.apply(
                        indices_kv, key_layer, value_layer
                    )
            else:
                # Cumulative sequence lengths for unpadded data
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )
        elif qkv_format == 'thd':
            assert not context_parallel, "thd format not supported with context parallelism!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel:
            assert (
                window_size in ((-1, -1), (-1, 0))
                ), "Sliding window attention is not supported with context parallelism."
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training, query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic
                )
        else:

            from .cpu_offload import CPUOffloadEnabled
            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Only pass keyword arguments the installed flash-attn version accepts.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                output = flash_attn_forward_func(
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=1.0/self.norm_factor, causal="causal" in attn_mask_type,
                    **fa_optional_forward_kwargs,
                )

        # Only sbhd/bshd inputs were packed (and have `indices_q`); thd stays packed.
        if qkv_format in ['sbhd', 'bshd'] and 'padding' in attn_mask_type:
            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == 'sbhd':
            # (bs)hd -> bs(hd) -> sb(hd)
            output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
        elif qkv_format == 'bshd':
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q, -1).contiguous()
        elif qkv_format == 'thd':
            # thd -> t(hd)
            output = output.view(output.shape[0], -1).contiguous()

        return output
1848
1849


1850
1851
1852
1853
1854
1855
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
                dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        """Run the fused-attention forward on packed `qkv`.

        The output and the auxiliary tensors returned by the kernel are saved for
        backward, along with every scalar setting backward needs.
        """
        out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
            is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
            fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        ctx.save_for_backward(qkv, out, cu_seqlens)
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Compute dqkv (and dbias when a bias is used).

        With `use_FAv2_bwd`, the flash-attn v2 backward kernel is used, taking
        `(softmax_lse, rng_state)` from the saved aux tensors. Otherwise the fused
        backward is used.

        Returns one gradient per forward input (15 in total). Only `qkv`
        (index 3) and `attn_bias` (index 5) can receive gradients.
        """
        d_out = d_out.contiguous()
        qkv, out, cu_seqlens = ctx.saved_tensors
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dqkv = torch.empty_like(qkv)

            def maybe_contiguous(x):
                # flash-attn requires unit stride in the last dimension.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, qkv[:,0], qkv[:,1], qkv[:,2], out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dqkv[:,0], dqkv[:,1], dqkv[:,2],
                cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            # Trim the head dimension back to that of d_out.
            dqkv = dqkv[..., :d_out.shape[-1]]
        else:
            dqkv, *rest = fused_attn_bwd_qkvpacked(
                ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # if no_bias or alibi, return dqkv
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (None, None, None, dqkv, None, None, None,
                    None, None, None, None, None, None, None, None)
        # else, return (dqkv, dbias)
        return (None, None, None, dqkv, None, rest[0], None,
                None, None, None, None, None, None, None, None)

1917

1918
1919
1920
1921
1922
1923
1924
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    Autograd wrapper around ``fused_attn_fwd_kvpacked`` / ``fused_attn_bwd_kvpacked``.
    The query is a separate tensor, while key and value share one packed tensor ``kv``;
    on the flash-attn backward path K and V are taken as ``kv[:, 0]`` and ``kv[:, 1]``.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        """Run the fused attention forward kernel and stash what backward needs.

        Returns
        -------
        out : torch.Tensor
            Attention output produced by ``fused_attn_fwd_kvpacked``.
        """
        # NOTE(review): the five ``None`` arguments are optional kernel inputs left unset on
        # this path (presumably FP8 scaling tensors) -- confirm against fused_attn_fwd_kvpacked.
        out, aux_ctx_tensors = fused_attn_fwd_kvpacked(
            is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
            q, kv, qkv_dtype, fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv)
        # Auxiliary tensors are kept directly on ctx rather than via save_for_backward.
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for ``q``, ``kv`` and, when a bias is used, ``attn_bias``.

        Uses flash-attn's ``flash_attn_cuda_bwd`` when ``ctx.use_FAv2_bwd`` is set and
        ``fused_attn_bwd_kvpacked`` otherwise. The returned tuple is positional and lines up
        with the inputs of :meth:`forward`: index 5 is ``dq``, index 6 is ``dkv`` and index 8
        is the gradient of ``attn_bias``; everything else is ``None``.
        """
        d_out = d_out.contiguous()
        q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
        # Both backward paths consume aux_ctx_tensors[0]; make sure it is contiguous.
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)

            def maybe_contiguous(x):
                # Copy only when the last dimension is not unit-stride.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, q, kv[:,0], kv[:,1], out)]
            # NOTE(review): the trailing ``False`` / ``None`` arguments presumably map to
            # flash-attn's ``zero_tensors`` / ``gen`` parameters -- confirm for the pinned
            # flash-attn version.
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dkv[:,0], dkv[:,1],
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            # Trim the last dimension to d_out's head size (no-op when they already match).
            dq = dq[..., :d_out.shape[-1]]
            dkv = dkv[..., :d_out.shape[-1]]
        else:
            dq, dkv, *rest = fused_attn_bwd_kvpacked(
                ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # if no_bias or alibi, only dq and dkv carry gradients
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (None, None, None, None, None, dq, dkv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, also return dbias (rest[0]) in the attn_bias slot
        return (None, None, None, None, None, dq, dkv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)


class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors.

    Autograd wrapper around ``fused_attn_fwd`` / ``fused_attn_bwd``, with an optional
    flash-attn (``flash_attn_cuda_bwd``) backward path selected by ``use_FAv2_bwd``.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        """Run the fused attention forward kernel and stash what backward needs.

        Returns
        -------
        out : torch.Tensor
            Attention output produced by ``fused_attn_fwd``.
        """
        # NOTE(review): the five ``None`` arguments are optional kernel inputs left unset on
        # this path (presumably FP8 scaling tensors) -- confirm against fused_attn_fwd.
        out, aux_ctx_tensors = fused_attn_fwd(
            is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
            q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        # With CPU activation offloading enabled, flag every saved tensor for offload and
        # record the 'sbhd_sbhd_sbhd' layout for backward (the kernel above already ran with
        # the caller's layout).
        # NOTE(review): the layout override presumably matches how offloaded tensors are
        # restored (as separate Q/K/V) -- confirm with cpu_offload.
        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv]
            qkv_layout = 'sbhd_sbhd_sbhd'
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv)
        # Auxiliary tensors are kept directly on ctx rather than via save_for_backward.
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for ``q``, ``k``, ``v`` and, when a bias is used, ``attn_bias``.

        The returned tuple is positional and lines up with the inputs of :meth:`forward`:
        indices 5/6/7 are ``dq``/``dk``/``dv`` and index 9 is the gradient of ``attn_bias``;
        everything else is ``None``.
        """
        d_out = d_out.contiguous()
        q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
        # Both backward paths consume aux_ctx_tensors[0]; make sure it is contiguous.
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)

            def maybe_contiguous(x):
                # Copy only when the last dimension is not unit-stride.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, q, k, v, out)]
            # NOTE(review): the trailing ``False`` / ``None`` arguments presumably map to
            # flash-attn's ``zero_tensors`` / ``gen`` parameters -- confirm for the pinned
            # flash-attn version.
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dk, dv,
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            # Trim the last dimension to d_out's head size (no-op when they already match).
            dq = dq[..., :d_out.shape[-1]]
            dk = dk[..., :d_out.shape[-1]]
            dv = dv[..., :d_out.shape[-1]]
        else:
            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # if no_bias or alibi, only dq, dk and dv carry gradients
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            return (None, None, None, None, None, dq, dk, dv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, also return dbias (rest[0]) in the attn_bias slot
        return (None, None, None, None, None, dq, dk, dv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)


class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Store attention settings and configure cuDNN workspace behaviour.

        Note: this constructor writes ``CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT`` into the
        process environment when ``deterministic`` is set or when
        ``NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`` is "0" or "1".
        """
        super().__init__()

        self.norm_factor = norm_factor
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # flash-attn v2 backward is opt-in and only used on compute capability 9.0 devices.
        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
                             and get_device_compute_capability() == (9, 0))
        self.layer_number = 1 if layer_number is None else layer_number
        if deterministic:
            # workspace optimization path is deterministic
            os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
        # - unset:       enables workspace optimization when required workspace is <= 256MB
        #                or when bias gradient needs to be computed
        # - n:           enables workspace optimization when required workspace is <= n bytes
        # - -1:          enables workspace optimization always
        # - 0:           disables workspace optimization always
        if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        fused_attention_backend:
            tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> torch.Tensor:
        """Fused attention fprop.

        Validates the inputs and derives the batch size and the max sequence lengths from
        the tensor shapes. For the supported ``sbhd``/``bshd`` formats, the passed
        ``max_seqlen_q`` / ``max_seqlen_kv`` values are overwritten. Missing ``cu_seqlens``
        are built either from ``attention_mask`` (padding masks) or as full-length
        sequences. The computation is then dispatched to the context-parallel path
        (``attn_forward_func_with_cp``) or to :class:`FusedAttnFunc`.

        Returns
        -------
        torch.Tensor
            Attention output with its last two dimensions (heads, head size) merged.

        Raises
        ------
        AssertionError
            On an unsupported backend, dtype, device, layout or format, or on
            CP-incompatible options.
        RuntimeError
            If a padding mask is requested but neither ``attention_mask`` nor both
            ``cu_seqlens`` are provided.
        """
        assert (fused_attention_backend
            != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
            ), 'No fused attention backend supports this input combination!'
        assert (
            (query_layer.dtype in [torch.float16, torch.bfloat16])
            and (key_layer.dtype in [torch.float16, torch.bfloat16])
            and (value_layer.dtype in [torch.float16, torch.bfloat16])
            ), 'FusedAttention only supports FP16 and BF16 data types.'
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'FusedAttention only supports CUDA tensors.'
        assert (
            qkv_layout in QKVLayouts
            ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # The format is the letters of the query part of the layout, e.g. 'sb3hd' -> 'sbhd'.
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
        assert (
            qkv_format != 'thd'
            ), 'FusedAttention does not support qkv_format = thd!'

        if qkv_format in ['sbhd', 'bshd']:
            if qkv_format == 'sbhd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1], query_layer.shape[0], key_layer.shape[0])
            if qkv_format == 'bshd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
            if 'padding' in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if cu_seqlens_q is None or cu_seqlens_kv is None:
                    if attention_mask is None:
                        raise RuntimeError(
                            "Please provide attention_mask or cu_seqlens for padding!"
                        )
                    if self.attention_type == "self":
                        cu_seqlens_q = get_cu_seqlens(attention_mask)
                        cu_seqlens_kv = cu_seqlens_q
                    else:
                        # Cross attention: attention_mask is a (q_mask, kv_mask) pair.
                        cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                        cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
            else:
                # No padding: every sequence spans the full (max) length.
                if cu_seqlens_q is None:
                    cu_seqlens_q = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_q,
                        query_layer.device,
                    )
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = _get_full_cu_seqlens(
                        batch_size,
                        max_seqlen_kv,
                        key_layer.device,
                    )

        qkv_dtype = TE_DType[query_layer.dtype]

        # The flash-attn backward is only used without bias and with the arbitrary-seqlen
        # backend, on top of the opt-in decided in __init__.
        use_FAv2_bwd = (self.use_FAv2_bwd
                and (core_attention_bias_type == "no_bias")
                and (fused_attention_backend
                    == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))

        if context_parallel:
            assert (fused_attention_backend
                == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
                ), f"{fused_attention_backend} does not work with context parallelism!"
            assert (core_attention_bias_type == "no_bias"), \
                "Core attention bias has not been supported with context parallelism yet!"
            # The CP path is fed batch-first tensors; sbhd inputs are transposed in and out.
            if qkv_format == 'sbhd':
                query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
                    for x in (query_layer, key_layer, value_layer)]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv,
                    max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    attn_mask_type=attn_mask_type,
                    use_fused_attention=True,
                )
            if qkv_format == 'sbhd':
                output = output.transpose(0,1).contiguous()
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q, max_seqlen_kv,
                    cu_seqlens_q, cu_seqlens_kv,
                    query_layer, key_layer, value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    1.0/self.norm_factor,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    None, # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)


class DotProductAttention(torch.nn.Module):
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes "`padding`" or "`arbitrary`".

    .. warning::

        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels : int
                number of key-value channels.
2295
2296
2297
2298
2299
2300
2301
2302
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
2303
2304
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
2305
    attn_mask_type: str, default = `causal`
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
                   type of attention mask passed into softmax operation, options are "`no_mask`",
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`", and
                   "`arbitrary`", where "`padding,causal`" and "`causal,padding`" are equivalent.
                   This arg can be overridden by :attr:`attn_mask_type` in the `forward` method.
                   It is useful for cases involving compilation/tracing, e.g. ONNX export, and the
                   forward arg is useful for dynamically changing mask types, e.g. a different mask
                   for training and inference. For "`no_mask`", no attention mask is applied. For
                   "`causal`" or the causal mask in "`padding,causal`", TransformerEngine calculates
                   and applies an upper triangular mask to the softmax input. No user input is
                   needed. For "`padding`" or the padding mask in "`padding,causal`", users need to
                   provide the locations of padded tokens either via :attr:`cu_seqlens_q` and
                   :attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
                   in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
                   need to provide a mask that is broadcastable to the shape of softmax input.
2320
2321
2322
2323
2324
2325
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
2326
2327
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
2328
2329
2330
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
               `h` the number of heads, `d` head size, and `t` the total number of sequences
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
               For that, please use `_get_qkv_layout` to gain the layout information.
2341
2342
2343
2344
2345
2346
2347
2348
2349

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
2350
2351
2352
2353
2354
2355
2356
2357
2358
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
2359
2360
2361
2362
2363
2364
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Store the configuration and build every backend that may be chosen at forward time.

        FlashAttention and FusedAttention are only instantiated when enabled through the
        environment (``NVTE_FLASH_ATTN`` / ``NVTE_FUSED_ATTN``) on compute capability >= 8.0.
        The unfused PyTorch implementation is always built.
        """
        super().__init__()

        self.qkv_format = qkv_format
        # Normalize "a,b" spellings to "a_b". "causal_padding" and "padding_causal" denote
        # the same mask, so store the canonical "padding_causal" form.
        attn_mask_type = attn_mask_type.replace(",","_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)

        # An explicit tp_group takes precedence over the tp_size argument.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.layer_number = 1 if layer_number is None else layer_number
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

        self.hidden_size_per_attention_head = kv_channels
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        # Divide by the resolved TP world size (self.tp_size), not the raw tp_size argument,
        # so the per-partition count is correct when only tp_group is supplied.
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (num_attention_heads % self.num_gqa_groups == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"

        # Dropout runs under the RNG tracker's fork only when a tracker is supplied and
        # sequence parallelism is off.
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            attention_dropout_ctx = get_rng_state_tracker().fork

        norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.device_compute_capability = get_device_compute_capability()
        # Deterministic if NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 or torch's deterministic mode is on.
        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) \
                             or torch.are_deterministic_algorithms_enabled()

        self.use_flash_attention = (
            int(os.getenv("NVTE_FLASH_ATTN", "1"))
            and self.device_compute_capability >= (8, 0)
        )
        if not _flash_attn_2_4_1_plus and self.deterministic:
            self.use_flash_attention = False
            warnings.warn(
                "Disabling usage of FlashAttention since version <2.4.1 does not support "
                "deterministic execution. In order to use FA with deterministic behavior,"
                " please install FlashAttention version >=2.4.1."
            )

        self.use_fused_attention = (
            int(os.getenv("NVTE_FUSED_ATTN", "1"))
            and self.device_compute_capability >= (8, 0)
        )

        assert (
            attention_type in AttnTypes
        ), f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        if self.use_flash_attention:
            self.flash_attention = FlashAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        if self.use_fused_attention:
            self.fused_attention = FusedAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        self.unfused_attention = UnfusedDotProductAttention(
            norm_factor, **attn_kwargs, layer_number=layer_number)

    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Any,
        **forward_kwargs: Any,
    ) -> torch.Tensor:
        """Forward method with activation checkpointing.

        Runs ``attention_func(*forward_args, **forward_kwargs)`` through :func:`checkpoint`.
        It passes this module's RNG state tracker and tensor-parallel group, and sets
        ``distribute_saved_activations=False``.
        """

        def custom_forward(*input_args, **input_kwargs):
            return attention_func(*input_args, **input_kwargs)

        # Positional args follow the function; the keyword settings go to checkpoint itself.
        hidden_states = checkpoint(
            custom_forward,
            *forward_args,
            distribute_saved_activations=False,
            get_rng_state_tracker=self.get_rng_state_tracker,
            tp_group=self.tp_group,
            **forward_kwargs,
        )

        return hidden_states

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

    @no_torch_dynamo(recursive=False)
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        checkpoint_core_attention: bool = False,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        inference_params: Optional[InferenceParams] = None,
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        .. note::

            With :attr:`qkv_format` = `"sbhd"`, input tensors :attr:`query_layer`,
            :attr:`key_layer`, and :attr:`value_layer` must each be of shape
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
            :attr:`kv_channels`). Output of shape (:attr:`sequence_length`,
            :attr:`batch_size`, :attr:`num_attention_heads` * :attr:`kv_channels`)
            is returned. `"bshd"` inputs are 4D with batch and sequence dimensions
            swapped; `"thd"` inputs are 3D and require :attr:`cu_seqlens_q` and
            :attr:`cu_seqlens_kv`.

        .. note::

            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.

        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
                   One of {`sbhd`, `bshd`, `thd`}.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      For `thd`, calculated from `cu_seqlens_q` if not provided; for
                      `sbhd`/`bshd`, always taken from the sequence dimension of `query_layer`.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       For `thd`, calculated from `cu_seqlens_kv` if not provided; for
                       `sbhd`/`bshd`, always taken from the sequence dimension of `key_layer`.
        attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
                       `arbitrary`}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
        window_size: Optional[Tuple[int, int]], default = `None`
                    Sliding window size for local attention. (-1, -1) and (-1, 0) are
                    treated as "no sliding window"; any other value disables FusedAttention.
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
        fast_zero_fill: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
        inference_params: Optional[InferenceParams], default = `None`
            Optimizes execution performance during inference by caching Keys and Values of the
            current decoding iteration. These cached values are appended to the K and V values
            computed in previous iterations, eliminating the need to recalculate them for the
            entire sequence.
            Initialization of `inference_params` is required prior to use to ensure sufficient
            memory allocation.
            Adjustments of the sequence_len_offset should be done after a complete forward pass.
            If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
            Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.

        Raises
        ------
        AssertionError
            If inputs violate any of the layout, dtype or mask constraints checked below.
        Exception
            If no attention backend can handle the provided inputs.
        """

        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'DotProductAttention only supports CUDA tensors.'

        assert (key_layer.shape == value_layer.shape
            ), "Keys and values must have the same shape!"

        # NOTE(review): window_size is only validated here when a call-level
        # attn_mask_type is given; otherwise the module default (or the caller's
        # window_size) is used as-is.
        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)
        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        else:
            # Normalize the comma-separated spellings to the canonical enum name.
            attn_mask_type = attn_mask_type.replace(",","_")
            if attn_mask_type == "causal_padding":
                attn_mask_type = "padding_causal"

        assert (attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"

        if window_size is None:
            window_size = self.window_size

        if qkv_format is None:
            qkv_format = self.qkv_format

        if inference_params is not None:
            assert self.layer_number is not None, "Layer number must be set!"
            # The KV-cache below is indexed as [seq, batch, ...]; a 3D `thd` tensor
            # would have its head dimension misread as the batch dimension.
            assert (qkv_format in ['sbhd', 'bshd']
                ), "inference_params only supports qkv_format = {'sbhd', 'bshd'}!"

            # The cache is laid out sequence-first, so temporarily view bshd inputs as sbhd.
            if qkv_format == "bshd":
                key_layer = key_layer.transpose(0, 1)
                value_layer = value_layer.transpose(0, 1)

            (inference_key_memory, inference_value_memory,
            ) = inference_params.key_value_memory_dict[self.layer_number]

            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)

            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)

            # Copy keys and values into KV-cache
            inference_key_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
            inference_value_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
            # Attend over the whole cached prefix, not just the new tokens.
            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

            if qkv_format == "bshd":
                key_layer = key_layer.transpose(0, 1)
                value_layer = value_layer.transpose(0, 1)

            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
            and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
        assert (qkv_format in ['sbhd', 'bshd', 'thd']
            ), "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

        if qkv_format == 'thd':
            assert (all(len(x.shape) == 3 for x in (query_layer, key_layer, value_layer))
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
                and len(cu_seqlens_q.shape) == 1
                and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_kv must both have shape [batch_size + 1]!"
            assert (cu_seqlens_q.dtype == torch.int32
                and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_kv must both be in dtype torch.int32!"
            # Derive the max sequence lengths from the cumulative offsets when not given.
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if qkv_format in ['sbhd', 'bshd']:
            assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
            # For padded layouts the max sequence length is the padded sequence dimension.
            if qkv_format == 'sbhd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
            if qkv_format == 'bshd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
            # torch.all reduces on-device instead of iterating element by element.
            if cu_seqlens_q is not None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                assert (torch.all(seqlens_q <= max_seqlen_q)
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                    the sequence dimention in 'query_layer'!"""
            if cu_seqlens_kv is not None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                assert (torch.all(seqlens_kv <= max_seqlen_kv)
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                    the sequence dimention in 'key_layer' and 'value_layer'!"""

        qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
            query_layer, key_layer, value_layer, qkv_format = qkv_format)

        # The priority for attention backends (subject to availability and clearing the filters)
        # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
        use_flash_attention = self.use_flash_attention
        use_fused_attention = self.use_fused_attention
        use_unfused_attention = True

        # The following section filters out some backends based on
        # certain asserts before executing the forward pass.

        # Filter: ONNX export.
        if is_in_onnx_export_mode():
            use_flash_attention = False
            use_fused_attention = False

        # Filter: Input type.
        if (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16]
        ):
            use_flash_attention = False
            use_fused_attention = False

        # Filter: Device and dimensions.
        # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
        # FAv2 requires head_dim % 8 == 0
        if (key_layer.shape[-1] > 256
            or key_layer.shape[-1] % 8 != 0
            or (key_layer.shape[-1] > 192
                and self.device_compute_capability not in ((8, 0), (9, 0)))):
            use_flash_attention = False

        # Filter: cross attention + causal mask.
        # (in training mode)
        if (inference_params is None
            and _flash_attn_2_1_plus
            and "causal" in attn_mask_type
            and max_seqlen_q != max_seqlen_kv
        ):
            warnings.warn(
                "In training mode, disable the use of FlashAttention since version 2.1+ has "
                "changed its behavior for causal mask in cross attention. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False

        context_parallel = (self.cp_group is not None and \
            get_distributed_world_size(self.cp_group) != 1)

        # Filter: sliding window attention.
        # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
        if window_size not in ((-1, -1), (-1, 0)):
            use_fused_attention = False
            if (not _flash_attn_2_3_plus) or context_parallel:
                use_flash_attention = False

        # Filter: Attention mask type.
        #   attn_mask_type(s)    |     supported backends
        # ------------------------------------------------
        #   no_mask              |     All
        #   padding              |     UnfusedDotProductAttention, FlashAttention, FusedAttention
        #   causal               |     All
        #   padding + causal     |     FlashAttention, FusedAttention
        #   arbitrary            |     UnfusedDotProductAttention
        #
        if attn_mask_type == "arbitrary":
            use_flash_attention = False
            use_fused_attention = False

        if (inference_params is None
            and "causal" in attn_mask_type
            and max_seqlen_q != max_seqlen_kv
        ):
            use_unfused_attention = False

        # Filter: bias.
        global _alibi_cache
        if alibi_slopes is not None:
            assert (core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
            # The first layer invalidates the module-level ALiBi cache for this step.
            if self.layer_number == 1:
                _alibi_cache["_alibi_slopes_require_update"] = True
                _alibi_cache["_alibi_bias_require_update"] = True
        if core_attention_bias_type == "alibi":
            assert (core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
            if (_alibi_cache["_num_heads"] != query_layer.shape[-2]
                or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
                or _alibi_cache["_alibi_slopes"] is None):
                _alibi_cache["_alibi_slopes_require_update"] = True
                _alibi_cache["_alibi_bias_require_update"] = True

        if core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None:
            use_flash_attention = False

        # FusedAttention receives user-provided ALiBi slopes as an explicit
        # post-scale bias tensor.
        fu_core_attention_bias_type = core_attention_bias_type
        fu_core_attention_bias = core_attention_bias
        if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None:
            fu_core_attention_bias_type = "post_scale_bias"
            _, fu_core_attention_bias = get_alibi(
                query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes,
                bias_dtype=query_layer.dtype)

        if (use_fused_attention
            and fu_core_attention_bias_type == "post_scale_bias"
            and (fu_core_attention_bias.shape[0] != 1
            or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
            if fu_core_attention_bias.requires_grad:
                # remove this line when cuDNN adds bwd support for
                # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
                use_fused_attention = False
            else:
                # max512 backend will only support [1, h, s, s]
                # NOTE(review): this mutates the process environment and persists
                # beyond this call.
                os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

        if use_fused_attention:
            fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype],
                TE_DType[key_layer.dtype],
                QKVLayout[qkv_layout],
                AttnBiasType[fu_core_attention_bias_type],
                AttnMaskType[attn_mask_type],
                self.attention_dropout,
                query_layer.shape[-2], # num_attn_heads
                key_layer.shape[-2], # num_gqa_groups
                max_seqlen_q,
                max_seqlen_kv,
                query_layer.shape[-1], # head_dim
            )
            # DPA does not support FP8; for FP8, use cpp_extensions modules directly
            is_backend_avail = (fused_attention_backend in
                [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
            # Context parallelism is only wired up for the arbitrary-seqlen backend.
            use_fused_attention = ( \
                use_fused_attention and is_backend_avail and \
                (not context_parallel or \
                 fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
            if (fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
                and fu_core_attention_bias_type == "post_scale_bias"
                and (fu_core_attention_bias.shape[0] != 1
                or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
                use_fused_attention = False

        # Filter: determinism.
        # backend                                  | deterministic
        # ---------------------------------------------------------
        # flash-attn v1                            | yes
        # flash-attn v2                            | no
        # FusedAttnBackend["F16_max512_seqlen"]    | yes
        # FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
        # UnfusedDotProductAttention               | yes
        #
        # Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
        # on sm90 architectures.
        #
        if (use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and self.deterministic
            and self.device_compute_capability != (9, 0)):
            use_fused_attention = False

        # Select FusedAttention on sm90 and FlashAttention on others for performance
        if (use_flash_attention
            and use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
            if self.device_compute_capability == (9, 0):
                use_flash_attention = False

        if use_flash_attention:
            if _NVTE_DEBUG:
                print("[DotProductAttention]: using flash-attn",_flash_attn_version)
            if core_attention_bias_type == "alibi":
                alibi_slopes, _ = get_alibi(
                    query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes)
            return self.flash_attention(query_layer,
                                        key_layer,
                                        value_layer,
                                        attention_mask=attention_mask,
                                        qkv_layout=qkv_layout,
                                        cu_seqlens_q=cu_seqlens_q,
                                        cu_seqlens_kv=cu_seqlens_kv,
                                        attn_mask_type=attn_mask_type,
                                        window_size=window_size,
                                        alibi_slopes=alibi_slopes,
                                        cp_group=self.cp_group,
                                        cp_global_ranks=self.cp_global_ranks,
                                        cp_stream=self.cp_stream,
                                        max_seqlen_q=max_seqlen_q,
                                        max_seqlen_kv=max_seqlen_kv)

        if use_fused_attention:
            if _NVTE_DEBUG:
                print("[DotProductAttention]: using cuDNN fused attention (backend "
                    + str(int(fused_attention_backend)) + ")")
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.fused_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    fused_attention_backend=fused_attention_backend,
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv)
            return self.fused_attention(
                query_layer,
                key_layer,
                value_layer,
                qkv_layout=qkv_layout,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                attn_mask_type=attn_mask_type,
                attention_mask=attention_mask,
                fused_attention_backend=fused_attention_backend,
                core_attention_bias_type=fu_core_attention_bias_type,
                core_attention_bias=fu_core_attention_bias,
                fast_zero_fill=fast_zero_fill,
                cp_group=self.cp_group,
                cp_global_ranks=self.cp_global_ranks,
                cp_stream=self.cp_stream,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv)

        assert (not context_parallel), \
            "Context parallelism is only implemented with Flash Attention and Fused Attention!"

        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            warnings.warn(
                           "Attention activation Offloading is only implemented "
                           "with Flash Attention and Fused Attention!"
                         )

        if use_unfused_attention:
            if _NVTE_DEBUG:
                print("[DotProductAttention]: using unfused DPA")
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.unfused_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout = qkv_layout,
                    cu_seqlens_q = cu_seqlens_q,
                    cu_seqlens_kv = cu_seqlens_kv,
                    attn_mask_type = attn_mask_type,
                    attention_mask = attention_mask,
                    core_attention_bias_type = core_attention_bias_type,
                    core_attention_bias = core_attention_bias,
                    alibi_slopes = alibi_slopes)
            return self.unfused_attention(query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout = qkv_layout,
                    cu_seqlens_q = cu_seqlens_q,
                    cu_seqlens_kv = cu_seqlens_kv,
                    attn_mask_type = attn_mask_type,
                    attention_mask = attention_mask,
                    core_attention_bias_type = core_attention_bias_type,
                    core_attention_bias = core_attention_bias,
                    alibi_slopes = alibi_slopes)

        raise Exception("No dot product attention support for the provided inputs!")


class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                   default = `causal`
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
            For that, please use `_get_qkv_layout` to gain the layout information.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_split_rs: bool = False,
        ub_split_ag: bool = False,
        ub_atomic_gemm_rs: bool = False,
        ub_atomic_gemm_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """Build the QKV projection(s), the core attention and the output projection.

        All arguments are described in the class docstring.

        The query projection width is ``kv_channels * num_attention_heads``.
        The key/value projection width is ``kv_channels * num_gqa_groups``.
        Both equal the previous ``hidden_size``-based widths whenever
        ``kv_channels == hidden_size // num_attention_heads`` (the default).

        Raises
        ------
        AssertionError
            if ``attention_type`` is unsupported, ``layer_number`` is not positive,
            ``num_attention_heads`` is not divisible by ``num_gqa_groups``, or
            ``num_gqa_groups`` is not divisible by the tensor parallel size.
        """
        super().__init__()

        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        # Validate/normalize the sliding window against the mask type once; the
        # forward pass falls back to this value when no override is given.
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaving only applies to a single fused QKV weight.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # An explicitly supplied process group takes precedence over `tp_size`.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.hidden_size_per_attention_head = kv_channels
        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)

        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        assert (num_attention_heads % self.num_gqa_groups == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (self.num_gqa_groups % tp_size == 0
                ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        # Derive projection widths from the per-head size, so an explicit
        # `kv_channels` that differs from hidden_size / num_attention_heads still
        # produces tensors that `forward` can view as (..., heads, kv_channels).
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            parameters_split = None
            if not fuse_qkv_params:
                parameters_split = collections.OrderedDict([
                    ("query", self.hidden_size_q),
                    ("key", self.hidden_size_kv),
                    ("value", self.hidden_size_kv),
                ])
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_split_ag=ub_split_ag,
                    normalization=normalization,
                    ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    self.hidden_size_q + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    self.hidden_size_q,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_split_ag=ub_split_ag,
                    normalization=normalization,
                    ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    self.hidden_size_q,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            kv_channels,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Linear: maps the concatenated attention heads back to `hidden_size`.
        self.proj = Linear(
            self.hidden_size_q,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_split_rs=ub_split_rs,
            ub_split_ag=ub_split_ag,
            ub_atomic_gemm_rs=ub_atomic_gemm_rs,
            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )

    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return an uninitialized key or value cache buffer for incremental inference.

        The buffer is shaped ``[inference_max_sequence_len, batch_size,
        num_gqa_groups_per_partition, hidden_size_per_attention_head]`` and is
        allocated with ``dtype`` on the current CUDA device.
        """
        cache_shape = (
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        return torch.empty(cache_shape, dtype=dtype, device=torch.cuda.current_device())

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is also forwarded to every descendant module that defines
        ``set_tensor_parallel_group``. This lets a module built with only
        ``tp_size`` (and no ``tp_group``) have its submodules wired up after
        initialization.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        This module only delegates. The three arguments are passed unchanged to
        every descendant module that defines ``set_context_parallel_group``.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        descendants = self.modules()
        # `modules()` yields `self` first; drop it so the call does not recurse forever.
        next(descendants)
        for submodule in descendants:
            if hasattr(submodule, "set_context_parallel_group"):
                submodule.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

    def forward(
        self,
        hidden_states: torch.Tensor,
3411
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
3412
        encoder_output: Optional[torch.Tensor] = None,
3413
        attn_mask_type: Optional[str] = None,
3414
        window_size: Optional[Tuple[int, int]] = None,
3415
3416
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
3417
        inference_params: Optional[InferenceParams] = None,
3418
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
3419
3420
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
3421
        alibi_slopes: Optional[torch.Tensor] = None,
3422
        fast_zero_fill: bool = True,
3423
    ) -> Tuple[Union[torch.Tensor, None], ...]:
3424
3425
3426
3427
3428
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

3429
3430
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
3431
3432
3433
3434
3435

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
3436
3437
3438
3439
3440
3441
3442
3443
3444
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                       default = `None`
3445
                       type of attention mask passed into softmax operation.
3446
3447
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
3448
3449
3450
3451
3452
3453
3454
3455
3456
3457
3458
3459
3460
3461
3462
3463
3464
3465
3466
3467
3468
3469
3470
3471
3472
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
3473
                    Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`, `alibi`}
3474
        core_attention_bias: Optional[torch.Tensor], default = `None`
3475
3476
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
3477
3478
3479
3480
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
3481
3482
3483
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
3484
3485
        # hidden_states: [sq, b, h]

3486
3487
        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)
3488
        if attn_mask_type is None:
3489
            attn_mask_type = self.attn_mask_type
3490
3491
        if window_size is None:
            window_size = self.window_size
3492

3493
3494
3495
3496
3497
        if "padding" in attn_mask_type and attention_mask is not None:
            for i,_ in enumerate(attention_mask):
                assert (
                    attention_mask[i].dtype == torch.bool
                ), "Attention mask must be in boolean type!"
3498

3499
3500
        assert (core_attention_bias_type in AttnBiasTypes
                ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
3501

3502
        # =================================================
3503
        # Pre-allocate memory for key-values for inference
3504
3505
3506
3507
        # =================================================

        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
3508
                inf_max_seq_len = inference_params.max_sequence_length
3509
3510
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
3511
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
3512
3513
                )
                inference_value_memory = self._allocate_memory(
3514
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
3515
3516
3517
3518
3519
3520
3521
3522
3523
3524
3525
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

3526
        # ======================
3527
        # Query, Key, and Value
3528
        # ======================
3529

cyanguwa's avatar
cyanguwa committed
3530
3531
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
3532
3533
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545
3546
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

cyanguwa's avatar
cyanguwa committed
3547
3548
            num_queries_per_key_value = (self.num_attention_heads_per_partition //
                                         self.num_gqa_groups_per_partition)
3549
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
3550
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
3551
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
3552
3553
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
3554
3555
3556
3557
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
3558
3559
3560
3561
3562
3563
3564
3565
3566
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head
                )
                # split along third last dimension
                split_dim = -3
3567
3568
3569

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
3570
3571
3572
3573
3574
3575
3576
3577
3578
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
3579
                )
3580
            else:
cyanguwa's avatar
cyanguwa committed
3581
3582
3583
3584
3585
3586
3587
3588
3589
3590
3591
3592
                query_layer, key_layer, value_layer = torch.split(
                    mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim,
                 )

            # query: -> [sq, b, np, hn]
            # key, value: -> [sq, b, ng, hn]
            query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1,
                                                             self.hidden_size_per_attention_head)
                                                   for x in (query_layer, key_layer, value_layer))

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
3593
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
3594
                encoder_output,
3595
3596
3597
3598
                is_first_microbatch=is_first_microbatch,
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
3599
                # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
3600
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
3601
                    self.num_gqa_groups_per_partition,
3602
3603
3604
3605
3606
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
3607
                # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
3608
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
3609
                    2 * self.num_gqa_groups_per_partition,
3610
3611
3612
3613
3614
3615
3616
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
3617
3618
3619
3620
3621
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2,
                )
3622
            else:
cyanguwa's avatar
cyanguwa committed
3623
3624
3625
                key_layer, value_layer = torch.split(
                    mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
                )
3626
3627
3628
3629
3630
3631
3632
3633
3634
3635
3636
3637
3638
3639
3640
3641
3642
3643
3644
3645
3646
3647
3648
3649

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

3650
3651
3652
        # ======================================================
        # Apply relative positional encoding (rotary embedding)
        # ======================================================
3653

3654
        if rotary_pos_emb is not None:
3655
            # duplicate the pos_emb for self attention
3656
3657
3658
3659
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

            q_pos_emb, k_pos_emb = rotary_pos_emb
3660
3661
3662
3663
3664
3665
3666
3667
3668
3669
3670
3671
3672
3673

            # adjust key and value for inference
            if inference_params is not None:
                if self.qkv_format == "sbhd":
                    sequence_length = key_layer.size(0)
                elif self.qkv_format == "bshd":
                    sequence_length = key_layer.size(1)

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + sequence_length

                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

3674
3675
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
3676

3677
3678
3679
3680
        # ===========================
        # Core attention computation
        # ===========================

3681
3682
3683
3684
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
3685
            qkv_format=self.qkv_format,
3686
3687
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
3688
3689
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
3690
            window_size=window_size,
3691
3692
3693
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
3694
            alibi_slopes=alibi_slopes,
3695
            fast_zero_fill=fast_zero_fill,
3696
            inference_params=inference_params,
3697
3698
        )

3699
        # ===================
3700
        # Output. [sq, b, h]
3701
        # ===================
3702

3703
        projection_output = self.proj(
3704
3705
3706
            context_layer, is_first_microbatch=is_first_microbatch
        )

3707
3708
3709
3710
3711
3712
3713
3714
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
3715
        if self.input_layernorm and self.return_layernorm_output:
3716
3717
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]