attention.py 128 KB
Newer Older
1
2
3
4
5
6
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Attention."""
import os
7
import warnings
8
9
10
import math
from importlib.metadata import version
from contextlib import nullcontext
11
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
12
from pkg_resources import packaging
cyanguwa's avatar
cyanguwa committed
13
import numpy as np
14
15

import torch
16
import torch.nn.functional as F
17
18

import transformer_engine_extensions as tex
19
20
21
22
23
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
24
25
    fused_attn_fwd,
    fused_attn_bwd,
26
27
28
29
30
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
31
32
33
34
35
36
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
37
    get_default_init_method,
38
39
40
41
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
42
    AttnBiasTypes,
43
    QKVLayouts,
44
    dist_group_type,
45
    TE_DType,
46
47
48
49
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
50
    get_distributed_rank,
51
52
53
    checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
54
from transformer_engine.pytorch.jit import jit_fuser
55
56

_flash_attn_version = packaging.version.Version(version("flash-attn"))
57
_flash_attn_version_required = packaging.version.Version("1.0.6")
58
59
60
61
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")

if _flash_attn_2_available:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
62
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
63
64
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
65
else:
66
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module,ungrouped-imports
67
    from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
68
69


70
71
72
_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None


73
__all__ = ["DotProductAttention", "MultiheadAttention"]
74
75


76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1,] containing the cumulative sequence
    lengths of every sample in the batch and the indices containing valid
    samples.

    Nonzero (``True``) entries of `mask` are counted as valid tokens.

    Returns
    -------
    cu_seqlens: torch.Tensor
        int32 tensor of shape [batch_size + 1], starting at 0, placed on
        ``mask.device``.
    indices: torch.Tensor
        Tensor of shape [batch_size * max_seqlen, 1, 1] holding the flattened
        positions of the valid tokens in order, right-padded with the value
        ``batch_size * max_seqlen`` (one past the last flattened position).
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    reduced_mask = mask.sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Create the leading zero on the mask's own device. A hard-coded "cuda"
    # fails for CPU masks and mismatches masks living on a non-current GPU.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    mask = mask.reshape(-1)
    indices = mask.nonzero()
    indices = indices.unsqueeze(-1)

    # Pad to a fixed length of bs * seqlen so the result shape does not depend
    # on how many tokens are valid; padded entries point one past the end.
    num_nonzeros = indices.shape[0]
    pad_amount = bs * seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * seqlen))

    return cu_seqlens, indices


@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Gather the dim-0 rows of a 3-D `tensor` selected by `indices`.

    One all-zero row is appended to `tensor` before gathering, so an index
    equal to ``tensor.shape[0]`` selects zeros. `indices` is expanded along
    dims 1 and 2 to match `tensor`, so it is expected to have shape [N, 1, 1].
    """
    dim1 = tensor.shape[1]
    dim2 = tensor.shape[2]
    zero_row = torch.zeros(
        1, dim1, dim2, dtype=tensor.dtype, device=tensor.device)
    extended = torch.cat((tensor, zero_row), dim=0)
    gather_index = indices.repeat(1, dim1, dim2)
    return torch.gather(extended, 0, gather_index)


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply `pack_tensor` with the same `indices` to two tensors and return
    both packed results, in argument order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply `pack_tensor` with the same `indices` to three tensors and return
    the packed results, in argument order.
    """
    return (
        pack_tensor(indices, t1),
        pack_tensor(indices, t2),
        pack_tensor(indices, t3),
    )


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Scatter the rows of a packed 3-D `tensor` back to positions `indices`
    in a zero-filled tensor with `dim0` rows (inverse of `pack_tensor`).

    A spare trailing row absorbs writes from indices equal to `dim0`
    (padding entries) and is dropped before returning.
    """
    dim1 = tensor.shape[1]
    dim2 = tensor.shape[2]
    scatter_index = indices.repeat(1, dim1, dim2)
    scratch = torch.zeros(
        dim0 + 1, dim1, dim2, dtype=tensor.dtype, device=tensor.device)
    scratch.scatter_(0, scatter_index, tensor)
    return scratch[:dim0]


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply `unpack_tensor` with the same `indices` and `dim0` to two tensors
    (inverse of `pack_2_tensors`).
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply `unpack_tensor` with the same `indices` and `dim0` to three tensors
    (inverse of `pack_3_tensors`).
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward gathers the rows selected by `indices` from one to three tensors
    (via `pack_tensor`, `pack_2_tensors` or `pack_3_tensors`). Backward
    scatters the incoming gradients back with the matching `unpack_*` helper.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        *tensors: torch.Tensor
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        # Backward needs the gather indices and the original leading dimension
        # to rebuild the unpacked layout.
        # NOTE(review): dim0 is read from the first tensor only; all packed
        # tensors are assumed to share it.
        ctx.indices = indices
        ctx.dim0 = tensors[0].shape[0]
        if len(tensors) == 1:
            return pack_tensor(indices, *tensors)
        if len(tensors) == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_3_tensors(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: torch.Tensor):
        # `indices` gets no gradient (leading None); one gradient is returned
        # per packed input tensor.
        if len(grad_outputs) == 1:
            return None, unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs)
        if len(grad_outputs) == 2:
            return None, *unpack_2_tensors(ctx.indices, ctx.dim0, *grad_outputs)
        return None, *unpack_3_tensors(ctx.indices, ctx.dim0, *grad_outputs)


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function that scatters a packed tensor back to a padded layout
    of `dim0` rows; its backward gathers the gradient with the same indices.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        # Keep the indices so backward can gather the gradient back.
        ctx.indices = indices
        unpacked = unpack_tensor(indices, dim0, tensor)
        return unpacked

    @staticmethod
    def backward(ctx, grad_output):
        # Gather is the adjoint of the scatter done in forward; `indices`
        # and `dim0` receive no gradient.
        grad_tensor = pack_tensor(ctx.indices, grad_output)
        return None, None, grad_tensor


def _unpack_attn_mask_type(attn_mask_type: str) -> Tuple[str, bool]:
    """
    Split a comma-separated attention mask type into a single mask type plus
    a flag saying whether a causal mask must also be applied.

    Accepted combinations: "causal" alone -> ("causal", True);
    "causal" with "padding" -> ("padding", True); any other single type
    -> (type, False). Every entry must be listed in `AttnMaskTypes`.
    Any other combination raises.
    """
    requested = attn_mask_type.split(',')
    assert (
        all(entry in AttnMaskTypes for entry in requested)
    ), f"Mask type {attn_mask_type} is not supported."

    # Pull out (one occurrence of) "causal" and remember it as a flag.
    apply_causal = "causal" in requested
    if apply_causal:
        requested.remove("causal")

    if not requested:  # Causal mask only.
        return "causal", True
    if len(requested) != 1:
        raise RuntimeError("Unsupported combination of mask types.")

    (remaining,) = requested
    if apply_causal:  # Causal + padding is the only supported pairing.
        assert remaining == "padding", f"Causal + {remaining} masking not supported."
        return "padding", True
    return remaining, False


274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                               recv_tensor, recv_src,
                               cp_group, batch_p2p_comm):
    """Point-to-point communications of KV and dKV in Flash Attention with context parallelism

    Posts one send of `send_tensor` to `send_dst` and one receive into
    `recv_tensor` from `recv_src` within `cp_group`, and returns the request
    handles. If `batch_p2p_comm` is truthy the pair goes through
    `torch.distributed.batch_isend_irecv`; otherwise `isend`/`irecv` are
    issued directly.
    """
    dist = torch.distributed

    if batch_p2p_comm:
        def make_send():
            return dist.P2POp(dist.isend, send_tensor, send_dst, cp_group)

        def make_recv():
            return dist.P2POp(dist.irecv, recv_tensor, recv_src, cp_group)
    else:
        def make_send():
            return dist.isend(send_tensor, send_dst, cp_group)

        def make_recv():
            return dist.irecv(recv_tensor, recv_src, cp_group)

    # Even ranks send first and odd ranks receive first. This is presumably so
    # that neighbouring ranks post matching operations in a compatible order.
    ordered = (make_send, make_recv) if rank % 2 == 0 else (make_recv, make_send)
    ops = [make() for make in ordered]

    if batch_p2p_comm:
        return dist.batch_isend_irecv(ops)
    return ops


@torch.jit.script
def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Flash Attention with context parallelism

    Adds out_per_step * exp(softmax_lse_per_step - softmax_lse) to `out` in
    place. The weight is transposed on dims 1 and 2 and given a trailing
    singleton dim so it broadcasts against the output. With the callers'
    layouts that is LSE [b, np, s] -> weight [b, s, np, 1].
    """
    weight = (softmax_lse_per_step - softmax_lse).exp()
    weight = weight.transpose(1, 2).unsqueeze(-1)
    out.add_(out_per_step * weight)


@torch.jit.script
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Flash Attention with context parallelism

    Updates `softmax_lse` in place to
    log(exp(softmax_lse) + exp(softmax_lse_per_step)), with
    `softmax_lse_per_step` promoted to double first. torch.logaddexp
    evaluates this as max + log1p(exp(-|a - b|)), so the result stays finite
    for large log-sum-exp values. The previous exp -> add -> log sequence
    overflowed to inf once an entry exceeded ~709 (even in double).
    The result is written back with copy_ so views passed by the caller
    (e.g. one half of a larger LSE tensor) are updated in place.
    """
    softmax_lse.copy_(torch.logaddexp(softmax_lse, softmax_lse_per_step.to(torch.double)))


class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
    """
    Flash Attention implementation with context parallelism.
    Split flash attention compute into multiple steps, and overlap current-step
    compute with next-step communication.

    In forward, K/V blocks travel around a ring of the ranks in `cp_group`:
    each rank sends to rank + 1 and receives from rank - 1. At step i a rank
    therefore attends its local Q to the K/V that originated on rank
    (rank - i) % cp_size. The per-step outputs and softmax log-sum-exp (LSE)
    statistics are merged into the final output. In backward the ring runs
    in the opposite direction and carries both K/V and the partially
    accumulated dK/dV.

    Only the causal path is implemented; the non-causal path hits
    `assert False, "Not implemented yet!"`.
    """

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                cp_group, cp_global_ranks, cp_stream, softmax_scale, causal, deterministic):
        """
        q, k, v
            Local sequence shard. Viewed here from [b, s, np, hn] as
            [b, 2, s//2, np, hn], i.e. two halves of the local sequence.
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
            Passed to flash-attn. They are halved (`//2`) on steps that only use
            one half of Q or K/V.
        cp_group, cp_global_ranks
            Context-parallel group and the global rank of each member (used as
            P2P peers).
        cp_stream
            Second CUDA stream; flash-attn steps alternate between it and the
            current stream.
        softmax_scale
            Defaults to hn ** -0.5 when None.
        causal
            Must be True; the non-causal path is not implemented.
        deterministic
            Only stored on ctx for backward (flash-attn 1.x `num_splits`).

        Returns the attention output viewed as [b*s, np, hn].
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        # Ring neighbours: KV goes to the next rank and comes from the previous one.
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
        recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size]
        # Batched P2P is opt-in via env var, and always used when cp_size == 2.
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

        # [b, s, np, hn] -> [b, 2, s//2, np, hn]
        # NOTE(review): presumably the two halves are the non-adjacent chunks a
        # load-balanced causal split gives each rank -- confirm with the caller.
        q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
        if _flash_attn_2_available:
            assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
        # Flash Attn inputs (double-buffered: slot i%2 is used by step i)
        q_inputs = [None, None]
        kv_inputs = [None, None]
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        # p2p_comm_buffers[i] holds the stacked [k, v] used at step i; slot 0 is local.
        p2p_comm_buffers = [None for _ in range(cp_size)]
        p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
        send_recv_reqs = [[], []]

        # Iteration i runs flash-attn step i (if i < cp_size) and merges the LSE
        # of step i-1 (if i > 0) on the other stream, overlapping the two.
        for i in range(cp_size+1):
            if i < cp_size:
                with torch.cuda.stream(flash_attn_streams[i%2]):
                    # wait until KV is received
                    for req in send_recv_reqs[(i+1)%2]:
                        req.wait()

                    # Start forwarding this step's KV while it is being used.
                    if i < (cp_size-1):
                        p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
                                                                         p2p_comm_buffers[i],
                                                                         send_dst,
                                                                         p2p_comm_buffers[i+1],
                                                                         recv_src,
                                                                         cp_group,
                                                                         batch_p2p_comm)

                    kv_inputs[i%2] = p2p_comm_buffers[i]
                    if causal:
                        if i == 0:
                            # Local KV: causal attention over the whole local sequence.
                            # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                            q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                            kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                            if _flash_attn_2_available:
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=True, return_softmax=False,
                                )
                            else:
                                out_per_step[i] = torch.empty_like(q_inputs[i%2])
                                _, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    out_per_step[i], cu_seqlens_q, cu_seqlens_k,
                                    max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
                                    causal=True, return_softmax=False,
                                )
                        elif i <= rank:
                            # KV from a lower rank (no wrap-around): all of Q attends,
                            # non-causally, to the first half of that KV only.
                            # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                            q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                            # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                            kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
                            kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                            if _flash_attn_2_available:
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                )
                            else:
                                out_per_step[i] = torch.empty_like(q_inputs[i%2])
                                _, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    out_per_step[i], cu_seqlens_q, cu_seqlens_k//2,
                                    max_seqlen_q, max_seqlen_k//2, dropout_p, softmax_scale,
                                    causal=False, return_softmax=False,
                                )
                        else:
                            # KV from a higher rank (wrapped around the ring): only the
                            # second half of Q attends, non-causally, to all of that KV.
                            # [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                            q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
                            # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                            kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                            if _flash_attn_2_available:
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                )
                            else:
                                out_per_step[i] = torch.empty_like(q_inputs[i%2])
                                _, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    out_per_step[i], cu_seqlens_q//2, cu_seqlens_k,
                                    max_seqlen_q//2, max_seqlen_k, dropout_p, softmax_scale,
                                    causal=False, return_softmax=False,
                                )
                    else:
                        # NOTE(review): stripped under `python -O`; execution would then
                        # continue and fail later with an unrelated NameError.
                        assert False, "Not implemented yet!"

            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
                    flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)

                with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
                    if causal:
                        if i == 1:
                            # Initialize accumulators from step 0; the merged LSE is kept
                            # in double precision while steps are combined.
                            out = torch.empty_like(q).zero_()
                            softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
                            )
                        elif (i-1) <= rank:
                            # Step i-1 covered all of Q: merge into the full LSE.
                            flash_attn_fwd_softmax_lse_correction(softmax_lse,
                                                                  softmax_lse_per_step[i-1])
                        else:
                            # Step i-1 covered only the second half of Q: merge into that half.
                            flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
                                                                  softmax_lse_per_step[i-1])
                    else:
                        assert False, "Not implemented yet!"

                if i < cp_size:
                    flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        # NOTE(review): `softmax_lse_` still views the double-precision tensor, so
        # the `i > rank` branch below passes double LSE while the other branch
        # passes this float copy (values are identical; dtypes differ).
        softmax_lse = softmax_lse.to(torch.float)
        # Rescale each step's output by exp(lse_step - lse_final) and accumulate into `out`.
        for i in range(cp_size):
            # [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn]
            out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
            if i <= rank:
                flash_attn_fwd_out_correction(out.view(*out_.shape),
                                              out_,
                                              softmax_lse,
                                              softmax_lse_per_step[i])
            else:
                flash_attn_fwd_out_correction(out[:, 1, ...],
                                              out_,
                                              softmax_lse_[..., 1, :],
                                              softmax_lse_per_step[i])

        # KV of the last step; backward starts its reverse ring from this block.
        kv = p2p_comm_buffers[-1]
        out = out.view(-1, *out.shape[-2:])
        ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
        ctx.rng_states = rng_states
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        return out

    @staticmethod
    def backward(ctx, dout):
        """
        Replay the forward steps in reverse order. A ring running in the
        opposite direction (send to rank - 1, receive from rank + 1) carries KV
        together with the partially accumulated dKV.

        Returns gradients for q, k and v; every other forward input gets None.
        """
        q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
        # Reverse ring relative to forward.
        send_dst = ctx.cp_global_ranks[(rank + cp_size - 1) % cp_size]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

        # [b, np, sq] -> [b, np, 2, sq//2]
        softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
        # [b*sq, np, hn] -> [b, 2, sq//2, np, hn]
        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        # Flash Attn outputs
        dq = torch.empty_like(q)

        # Two ping-pong buffers of shape [2, *kv.shape]: index 0 holds KV, index 1 holds dKV.
        p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
                            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
        p2p_comm_buffers[0][0].copy_(kv)
        send_recv_reqs = []

        # flash-attn 1.x only: num_splits=1 is passed when deterministic is requested.
        fa_optional_backward_kwargs = {}
        if not _flash_attn_2_available:
            fa_optional_backward_kwargs["num_splits"] = 1 if ctx.deterministic else 0

        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

            send_tensor = p2p_comm_buffers[i%2]
            recv_tensor = p2p_comm_buffers[(i+1)%2]
            # First step: no dKV has been produced yet, so only KV is exchanged.
            if i == 0:
                send_tensor = send_tensor[0]
                recv_tensor = recv_tensor[0]
            # Last step: only dKV is exchanged.
            if i == (cp_size-1):
                send_tensor = send_tensor[1]
                recv_tensor = recv_tensor[1]

            send_recv_reqs = flash_attn_p2p_communicate(rank,
                                                        send_tensor,
                                                        send_dst,
                                                        recv_tensor,
                                                        recv_src,
                                                        ctx.cp_group,
                                                        batch_p2p_comm)

            kv = p2p_comm_buffers[i%2][0]
            # In reversed order of fwd
            if ctx.causal:
                if i == (cp_size-1):
                    # Local KV (forward step 0): causal over the full local sequence.
                    # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                    q_ = q.view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
                    # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
                    # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
                    _flash_attn_backward(
                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                        dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        ctx.dropout_p, ctx.softmax_scale, True,
                        rng_state=ctx.rng_states[cp_size-i-1],
                        **fa_optional_backward_kwargs
                    )
                elif i >= (cp_size-rank-1):
                    # Mirrors forward `i <= rank`: all of Q vs. first half of KV.
                    # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                    q_ = q.view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
                    # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                    kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
                    # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
                    _flash_attn_backward(
                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                        dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
                        ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                        ctx.dropout_p, ctx.softmax_scale, False,
                        rng_state=ctx.rng_states[cp_size-i-1],
                        **fa_optional_backward_kwargs
                    )
                else:
                    # Mirrors forward `i > rank`: second half of Q vs. all of KV.
                    # [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                    q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
                    # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
                    # [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                    out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                    dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
                    _flash_attn_backward(
                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse_[..., 1, :],
                        dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
                        ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                        ctx.dropout_p, ctx.softmax_scale, False,
                        rng_state=ctx.rng_states[cp_size-i-1],
                        **fa_optional_backward_kwargs
                    )

                if i >= (cp_size-rank-1):
                    # [b*sq, np, hn] -> [b, 2, sq//2, np, hn]
                    dq_ = dq_.view(*dq.shape)
                else:
                    # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                    dq_ = dq_.view(dq.shape[0], *dq.shape[2:])

                # dq starts as torch.empty: the first write to each half uses
                # copy_, later writes accumulate with add_.
                if i > (cp_size-rank-1):
                    dq.add_(dq_)
                elif i == (cp_size-rank-1):
                    if rank == (cp_size-1):
                        dq.copy_(dq_)
                    else:
                        dq[:, 0, ...].copy_(dq_[:, 0, ...])
                        dq[:, 1, ...].add_(dq_[:, 1, ...])
                elif i > 0:
                    dq[:, 1, ...].add_(dq_)
                else:
                    dq[:, 1, ...].copy_(dq_)

                # wait until dKV is received
                for req in send_recv_reqs:
                    req.wait()

                # Accumulate this step's dKV into the receive buffer, which is
                # forwarded along the ring on the next step.
                dkv = p2p_comm_buffers[(i+1)%2][1]
                if i >= (cp_size-rank-1) and i != (cp_size-1):
                    # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
                else:
                    # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn]
                    dkv_ = dkv_.view(*dkv.shape)

                if i == (cp_size-1):
                    if rank == 0:
                        dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                        dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                    else:
                        dkv.add_(dkv_)
                elif i >= (cp_size-rank-1):
                    if i == 0 and rank == (cp_size-1):
                        dkv[:, :, 0, ...].copy_(dkv_)
                    else:
                        dkv[:, :, 0, ...].add_(dkv_)
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
                assert False, "Not implemented yet!"

        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
        dq = dq.view(q.shape[0], -1, *q.shape[-2:])
        # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
        dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
        # One gradient per forward input: q, k, v, then None for the 11 others.
        return dq, dkv[0], dkv[1], None, None, None, None, None, None, None, None, None, None, None


def flash_attn_forward_func_with_cp(q, k, v, cu_seqlens_q, cu_seqlens_k,
                                    max_seqlen_q, max_seqlen_k, dropout_p,
                                    cp_group, cp_global_ranks, cp_stream,
                                    softmax_scale=None, causal=False,
                                    deterministic=False):
    """Flash Attention implementation with context parallelism.

    Functional entry point for ``FlashAttnUnpaddedFuncWithCP``: every argument is
    forwarded positionally, in the order that autograd function's ``apply``
    expects, and its result is returned unchanged.
    """
    context_parallel_args = (cp_group, cp_global_ranks, cp_stream)
    return FlashAttnUnpaddedFuncWithCP.apply(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        dropout_p,
        *context_parallel_args,
        softmax_scale, causal, deterministic,
    )


694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """
    def __init__(
        self,
        dim: int,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension; one inverse frequency
            ``1 / 10000 ** (i / dim)`` is created for every even ``i`` in ``range(dim)``.
        seq_len_interpolation_factor: int, default = None
            position-interpolation factor (https://arxiv.org/abs/2306.15595).
            It is only applied when `pretrained_max_position_embeddings` is
            also given; on its own it has no effect.
        pretrained_max_position_embeddings: int, default = None
            max_position_embeddings used during pre-training, i.e. before
            position interpolation.
        """
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies.

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            constant added to every position before the frequencies are computed

        Returns
        -------
        torch.Tensor of shape [max_seq_len, 1, 1, 2 * len(inv_freq)]: the
        position/frequency products, laid out twice along the last dimension.
        """
        positions = (torch.arange(max_seq_len, device=self.inv_freq.device) + offset
                     ).type_as(self.inv_freq)

        factor = self.seq_len_interpolation_factor
        pretrained_len = self.pretrained_max_position_embeddings
        if pretrained_len is not None and factor is not None:
            if max_seq_len > pretrained_len * factor:
                # dynamic linear scaling: the requested length exceeds what the
                # fixed factor covers, so squeeze all positions into pretrained_len
                positions *= 1 / (max_seq_len / pretrained_len)
            else:
                # fixed linear scaling by the configured factor
                positions *= 1 / factor

        # outer product: one row per position, one column per inverse frequency
        freqs = torch.einsum('i , j -> i j', positions, self.inv_freq)
        # duplicate along the last dim so it lines up with both rotated halves
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb[:, None, None, :]

752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Swap the two halves of the last dimension and negate the new first half:
    for ``x = [x1, x2]`` (equal halves) return ``[-x2, x1]``.

    The halves are obtained with ``view``, so the last dimension must have even
    size and a view-compatible memory layout.
    """
    half = x.shape[-1] // 2
    halves = x.view(*x.shape[:-1], 2, half)
    first, second = halves[..., 0, :], halves[..., 1, :]
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary positional embedding to the leading ``freqs.shape[-1]`` channels of `t`.

    Parameters
    ----------
    t: torch.Tensor
        input tensor of shape [seq_length, ..., dim]
    freqs: torch.Tensor
        rotary positional embedding tensor of shape [seq_length, ..., rot_dim]
        with rot_dim <= dim (e.g. the output of RotaryPositionEmbedding)

    Returns
    -------
    torch.Tensor in the dtype of `t`. Channels ``[rot_dim:]`` of `t` are passed
    through unchanged and appended after the rotated channels.
    """
    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # Cast the trig tables to the input dtype. Without this, a low-precision `t`
    # (fp16/bf16) is promoted to the dtype of `freqs` (fp32 when it comes from
    # RotaryPositionEmbedding's float buffer), so the result would no longer match
    # `t_pass` or the FP16/BF16 inputs expected by downstream attention kernels.
    cos_ = freqs.cos().to(t.dtype)
    sin_ = freqs.sin().to(t.dtype)

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)


cyanguwa's avatar
cyanguwa committed
776
class _SplitAlongDim(torch.autograd.Function):
    """Split a tensor along one dimension, with a zero-copy backward when possible.

    Forward is ``torch.split``. In backward, if the incoming gradients are
    adjacent slices of a single buffer laid out exactly as one tensor split along
    `split_dim` would be, the combined gradient is returned as a view of that
    buffer; otherwise the gradients are concatenated with ``torch.cat``.
    """

    @staticmethod
    def forward(ctx,
                mixed_x_layer: torch.Tensor,
                split_dim: int,
                split_size_or_sections: Union[int, List[int], Tuple[int]],
    ) -> Tuple[torch.Tensor, ...]:
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)

    @staticmethod
    def backward(ctx,
                 *grad_outputs):
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            assert (len(grad_outputs) == len(ctx.split_size_or_sections)
                ), "Unequal number of gradients vs split sections for backprop!"

        first = grad_outputs[0]
        dims = len(first.shape)
        split_dim = (ctx.split_dim + dims) % dims

        # Autograd delivers each gradient with the shape of the matching forward
        # output, so chunk sizes are read off the gradients themselves. This is
        # also correct for an int split size whose last chunk is shorter.
        split_sizes = [grad.shape[split_dim] for grad in grad_outputs]

        # No-op path requirements: same storage, same strides, same extent in
        # every non-split dim, and each chunk starting exactly where the previous
        # one ends, i.e. at `base_offset + (elements before it) * stride`. The
        # real stride of the split dim must be used here (not the product of the
        # trailing sizes), because the combined view below is built from
        # `strides`; the two differ e.g. for a size-1 dim with arbitrary stride.
        strides = first.stride()
        step = strides[split_dim]
        base_ptr = first.untyped_storage().data_ptr()
        base_offset = first.storage_offset()
        expected_shape = list(first.shape)

        noop_ok = True
        start = 0
        for size, grad in zip(split_sizes, grad_outputs):
            expected_shape[split_dim] = size
            if (grad.stride() != strides or
                list(grad.shape) != expected_shape or
                grad.untyped_storage().data_ptr() != base_ptr or
                grad.storage_offset() != base_offset + start * step):
                noop_ok = False
                break
            start += size

        if noop_ok:
            full_shape = list(first.shape)
            full_shape[split_dim] = sum(split_sizes)
            ret = torch.Tensor().to(device=first.device, dtype=first.dtype)
            ret.set_(first.untyped_storage(), base_offset, full_shape, strides)
            return ret, None, None

        return torch.cat(grad_outputs, dim = split_dim), None, None
831
832


833

834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """core attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            inputs in the format encoded by `qkv_layout` (`sbhd` or `bshd`; `thd`
            is rejected). After the optional `bshd` -> `sbhd` transpose they are
            indexed as [s, b, np, hn]. key/value may have fewer heads than query
            (GQA); they are repeated to match, which requires the query head
            count to be divisible by the key head count.
        qkv_layout: str, default = "sbh3d"
            must be in QKVLayouts; only its format letters are used here.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            unused here.
        attn_mask_type: str, default = "causal"
            must be in AttnMaskTypes; forwarded to the fused scale-mask-softmax.
        attention_mask: optional
            forwarded to the fused scale-mask-softmax.
        core_attention_bias_type: str, default = "no_bias"
            "no_bias", "pre_scale_bias" (bias added before dividing by the scale)
            or "post_scale_bias" (bias added after scaling).
        core_attention_bias: Optional[torch.Tensor]
            required by both bias types; must have shape [1, np, sq, sk].

        Returns
        -------
        torch.Tensor of shape [sq, b, np * hn] for `sbhd` inputs, or
        [b, sq, np * hn] for `bshd` inputs.

        Raises
        ------
        ValueError
            if `core_attention_bias_type` is not one of the three values above.
        """
        assert (qkv_layout in QKVLayouts
            ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
        assert (qkv_format != 'thd'
            ), """UnfusedDotProductAttention does not support variable sequence lengths!"""
        if qkv_format == 'bshd':
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [x.transpose(0, 1)
                for x in [query_layer, key_layer, value_layer]]

        assert (
            attn_mask_type in AttnMaskTypes
        ), f"attn_mask_type {attn_mask_type} not supported"

        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            assert (query_layer.shape[2]%key_layer.shape[2]==0
                ),"The number of attention heads must be divisible by the number of GQA groups!"
            # GQA: repeat each key/value head so every query head has a partner.
            key_layer = key_layer.repeat_interleave(
                    query_layer.shape[2] // key_layer.shape[2], dim = 2)
            value_layer = value_layer.repeat_interleave(
                    query_layer.shape[2] // value_layer.shape[2], dim = 2)

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(
            output_size[2], output_size[0] * output_size[1], -1
        )
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.norm_factor
        if apply_qk_layer_scaling:
            scale *= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / scale),
            )

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            # torch.Size takes a single iterable of sizes.
            assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
                    ), "core_attention_bias must be in [1, h, sq, skv] shape!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3])
            matmul_result /= scale

        elif core_attention_bias_type == "post_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
                    ), "core_attention_bias must be in [1, h, sq, skv] shape!"
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / scale),
            )
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3])

        else:
            # Without this, matmul_result would still hold uninitialized memory
            # from torch.empty and be silently used as attention scores.
            raise ValueError(
                f"core_attention_bias_type {core_attention_bias_type} not supported")

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask, attn_mask_type, softmax_scale)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == 'sbhd':
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        if qkv_format == 'bshd':
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """Turn an interleaved QKV buffer into three separate q, k, v tensors.

    Used by FlashAttention for `sbh3d` inputs, where the results are then
    treated as batch-first (s, b, ...) -> (b, s, ...). Forward asks the
    ``tex.fa_prepare_fwd`` extension to rearrange the buffer and splits the
    result into three pieces along dim 0. Backward packs the three gradients
    with ``tex.fa_prepare_bwd`` and splits that along the last dim.
    """

    @staticmethod
    def forward(ctx,
                query_layer: torch.Tensor,
                key_layer: torch.Tensor,
                value_layer: torch.Tensor
    ) -> torch.Tensor:
        # The three inputs are non-contiguous views into one QKV buffer;
        # `query_layer` alone is handed to the extension to reach the whole
        # memory region, so key_layer and value_layer are not read here.
        packed = tex.fa_prepare_fwd(query_layer)
        q_part, k_part, v_part = split_tensor_along_dim(packed, 0, 3)
        return q_part.squeeze(0), k_part.squeeze(0), v_part.squeeze(0)

    @staticmethod
    def backward(ctx,
                 dq: torch.Tensor,
                 dk: torch.Tensor,
                 dv: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        packed_grad = tex.fa_prepare_bwd(dq, dk, dv)
        grad_q, grad_k, grad_v = split_tensor_along_dim(packed_grad, -1, 3)
        return grad_q, grad_k, grad_v

1058
1059
1060
1061
1062
1063
1064
def _get_qkv_layout(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qkv_format: str = 'sbhd',
    ) -> str:
    """Get qkv layout.

    Infer how `q`, `k` and `v` share memory by comparing their storages,
    strides, shapes and storage offsets.

    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of sequences in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}

    Raises
    ------
    AssertionError
        if any of `q`, `k`, `v` has a stride other than 1 in its last dimension.
    Exception
        if the tensors match none of the layouts above.
    """
    last_dims_unit_stride = all(t.stride(-1) == 1 for t in [q, k, v])
    assert last_dims_unit_stride, "q, k and v must have stride 1 in their last dimension!"

    def _describe(group):
        # Packing properties of `group`, always measured against its first tensor:
        # (same storage, same strides, same shapes,
        #  offsets are 0, d, 2d, ...       -> interleaved just before the last dim,
        #  offsets are 0, h*d, 2*h*d, ...  -> interleaved before the last two dims).
        ref = group[0]
        ptr = ref.untyped_storage().data_ptr()
        step_d = ref.shape[-1]
        step_hd = ref.shape[-1] * ref.shape[-2]
        offsets = [t.storage_offset() for t in group]
        return (
            all(t.untyped_storage().data_ptr() == ptr for t in group),
            all(t.stride() == ref.stride() for t in group),
            all(t.shape == ref.shape for t in group),
            all(off == idx * step_d for idx, off in enumerate(offsets)),
            all(off == idx * step_hd for idx, off in enumerate(offsets)),
        )

    (same_ptr_qkv, same_stride_qkv, same_shape_qkv,
     d_offsets_qkv, hd_offsets_qkv) = _describe([q, k, v])
    (same_ptr_kv, same_stride_kv, same_shape_kv,
     d_offsets_kv, hd_offsets_kv) = _describe([k, v])

    one_buffer_qkv = same_ptr_qkv and same_stride_qkv and same_shape_qkv
    one_buffer_kv = same_ptr_kv and same_stride_kv and same_shape_kv

    if one_buffer_qkv and hd_offsets_qkv and not d_offsets_qkv:
        # sb3hd, bs3hd, t3hd
        return qkv_format[:-2] + '3' + qkv_format[-2:]
    if one_buffer_qkv and d_offsets_qkv:
        # sbh3d, bsh3d, th3d
        return qkv_format[:-1] + '3' + qkv_format[-1:]
    if one_buffer_kv and hd_offsets_kv and not d_offsets_kv:
        # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
        return qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:]
    if one_buffer_kv and d_offsets_kv:
        # sbhd_sbh2d, bshd_bsh2d, thd_th2d
        return qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:]
    if same_stride_kv and same_shape_kv:
        # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
        return '_'.join([qkv_format] * 3)
    raise Exception("The provided qkv memory layout is not supported!")
1152

1153
1154

class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."

        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        # With a padding mask, only layer 1 recomputes the cu_seqlens/indices
        # (stored in module-level globals); other layers reuse them in forward.
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        attn_mask_type: str = "causal",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
    ) -> torch.Tensor:
        """flash-attn fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer: torch.Tensor
            FP16/BF16 CUDA tensors in the format encoded by `qkv_layout`
            (`sbhd`, `bshd` or `thd`).
        attention_mask: optional
            only used when `attn_mask_type == "padding"`: a single mask for
            self-attention, or a (query mask, key/value mask) pair otherwise.
        qkv_layout: str, default = "sbh3d"
            must be in QKVLayouts.
        cu_seqlens_q, cu_seqlens_kv: Optional[torch.Tensor]
            cumulative sequence lengths. Required for `thd`; for `sbhd`/`bshd`
            without padding they default to uniform steps of the max seqlen.
        attn_mask_type: str, default = "causal"
            "causal" enables causal masking in flash-attn; "padding" packs the
            inputs with the mask-derived indices and unpacks the output.
        cp_group, cp_global_ranks, cp_stream:
            context-parallel settings; context parallelism is used when
            `cp_group` is given and has more than one rank.

        Returns
        -------
        torch.Tensor: [s, b, h*d] for `sbhd`, [b, s, h*d] for `bshd`; for `thd`
        the flash-attn output is returned without reshaping.
        """

        assert (
            query_layer.dtype in [torch.float16, torch.bfloat16]
            and key_layer.dtype in [torch.float16, torch.bfloat16]
            and value_layer.dtype in [torch.float16, torch.bfloat16]
            ), "FlashAttention currently only supports FP16 and BF16."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
            ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])

        if qkv_format == 'sbhd':
            # For now just 128, will make it more general in the future
            if (query_layer.shape[-1] == 128 and
                query_layer.shape[0] * query_layer.shape[1] >= 512 and
                qkv_layout == "sbh3d"):
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
                                                                             key_layer,
                                                                             value_layer)
            else:
                query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
                    for x in (query_layer, key_layer, value_layer)]
        elif qkv_format == 'bshd':
            query_layer, key_layer, value_layer = [x.contiguous()
                for x in (query_layer, key_layer, value_layer)]

        global _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv
        # For sbhd/bshd the tensors are batch-first at this point. For thd these
        # values are not (b, s, s); max seqlens are recomputed below.
        batch_size, max_seqlen_q, max_seqlen_kv = (
                query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])

        if qkv_format in ['sbhd', 'bshd']:
            if not context_parallel:
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.view(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

            if attn_mask_type == 'padding':
                assert not context_parallel, "Padding mask not supported with context parallelism."

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if self.layer_number == 1:
                        _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
                    _cu_seqlens_kv = _cu_seqlens_q
                    query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
                        _indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if self.layer_number == 1:
                        _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask[0])
                        _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
                    query_layer_packed = PackTensors.apply(_indices_q, query_layer)
                    key_layer_packed, value_layer_packed = PackTensors.apply(
                        _indices_kv, key_layer, value_layer
                    )
                query_layer, key_layer, value_layer = (
                    query_layer_packed, key_layer_packed, value_layer_packed)
                cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
            else:
                if cu_seqlens_q is None:
                    cu_seqlens_q = torch.arange(
                            0,
                            (batch_size + 1) * max_seqlen_q,
                            step=max_seqlen_q,
                            dtype=torch.int32,
                            device=query_layer.device)
                if cu_seqlens_kv is None:
                    cu_seqlens_kv = torch.arange(
                            0,
                            (batch_size + 1) * max_seqlen_kv,
                            step=max_seqlen_kv,
                            dtype=torch.int32,
                            device=key_layer.device)
        elif qkv_format == 'thd':
            assert not context_parallel, "thd format is not supported for context parallelism!"
            assert (_flash_attn_2_available
                ), "flash-attn v2 is required for variable sequence length support!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
            seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
            max_seqlen_q = seqlens_q.max().item()
            max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel:
            with self.attention_dropout_ctx():
                output = flash_attn_forward_func_with_cp(
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    causal=attn_mask_type=="causal",
                    deterministic=self.deterministic
                )
        else:
            with self.attention_dropout_ctx():
                fa_optional_forward_kwargs = {}
                if not _flash_attn_2_available:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                output = flash_attn_forward_func(
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal",
                    **fa_optional_forward_kwargs
                )

        # Only sbhd/bshd inputs were packed with `_indices_q` above. thd inputs are
        # never packed here, so the module-level indices do not describe them
        # (and may be stale or not yet defined).
        if qkv_format in ['sbhd', 'bshd'] and attn_mask_type == 'padding':
            output = UnpackTensor.apply(_indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == 'sbhd':
            # (bs)hd -> bs(hd) -> sb(hd)
            output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
        elif qkv_format == 'bshd':
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q, -1).contiguous()

        return output
1325
1326


1327
1328
1329
1330
1331
1332
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input.

    The forward pass runs ``fused_attn_fwd_qkvpacked``. The backward pass runs
    either ``fused_attn_bwd_qkvpacked`` or, when ``use_FAv2_bwd`` is set,
    flash-attn's ``flash_attn_cuda_bwd`` fed with the softmax statistics and RNG
    state returned by the forward pass (``aux_ctx_tensors``).
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
                dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
            is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
            fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        ctx.save_for_backward(qkv, out, cu_seqlens)
        # Kernel configuration replayed verbatim by the backward pass.
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        qkv, out, cu_seqlens = ctx.saved_tensors
        if ctx.use_FAv2_bwd:
            # flash-attn's backward takes no bias, so it can neither apply it nor
            # produce its gradient. Fail with a clear message instead of the
            # UnboundLocalError on `rest` that the bias branch below would hit.
            if ctx.attn_bias_type != "no_bias":
                raise NotImplementedError(
                    "NVTE_FUSED_ATTN_USE_FAv2_BWD=1 does not support attention bias "
                    f"(attn_bias_type = {ctx.attn_bias_type}).")
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dqkv = torch.empty_like(qkv)

            def maybe_contiguous(x):
                # flash-attn requires a unit stride on the last dimension.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x)
                                   for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out)]
            # Gradients are written in place into the q/k/v slots of dqkv.
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
                cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                ctx.dropout_p, ctx.attn_scale, False,
                ctx.attn_mask_type == "causal", None, rng_state
            )
            # NOTE(review): presumably a no-op unless the head dim was padded;
            # mirrors flash-attn's own wrapper — confirm.
            dqkv = dqkv[..., :d_out.shape[-1]]
        else:
            dqkv, *rest = fused_attn_bwd_qkvpacked(
                ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # One gradient slot per forward input. Only qkv (index 3) and, with a
        # bias, attn_bias (index 5) receive gradients.
        if ctx.attn_bias_type == "no_bias":
            return (None, None, None, dqkv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, return (dqkv, dbias)
        return (None, None, None, dqkv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)

class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    The forward pass runs ``fused_attn_fwd_kvpacked``. The backward pass runs
    either ``fused_attn_bwd_kvpacked`` or, when ``use_FAv2_bwd`` is set,
    flash-attn's ``flash_attn_cuda_bwd`` fed with the softmax statistics and RNG
    state returned by the forward pass (``aux_ctx_tensors``).
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        out, aux_ctx_tensors = fused_attn_fwd_kvpacked(
            is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
            q, kv, qkv_dtype, fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv)
        # Kernel configuration replayed verbatim by the backward pass.
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
        if ctx.use_FAv2_bwd:
            # flash-attn's backward takes no bias, so it can neither apply it nor
            # produce its gradient. Fail with a clear message instead of the
            # UnboundLocalError on `rest` that the bias branch below would hit.
            if ctx.attn_bias_type != "no_bias":
                raise NotImplementedError(
                    "NVTE_FUSED_ATTN_USE_FAv2_BWD=1 does not support attention bias "
                    f"(attn_bias_type = {ctx.attn_bias_type}).")
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)

            def maybe_contiguous(x):
                # flash-attn requires a unit stride on the last dimension.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x)
                                   for x in (d_out, q, kv[:, 0], kv[:, 1], out)]
            # dk/dv are written in place into the k/v slots of dkv.
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                ctx.attn_mask_type == "causal", None, rng_state
            )
            # NOTE(review): presumably a no-op unless the head dim was padded;
            # mirrors flash-attn's own wrapper — confirm.
            dq = dq[..., :d_out.shape[-1]]
            dkv = dkv[..., :d_out.shape[-1]]
        else:
            dq, dkv, *rest = fused_attn_bwd_kvpacked(
                ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # One gradient slot per forward input. Only q (index 5), kv (index 6)
        # and, with a bias, attn_bias (index 8) receive gradients.
        if ctx.attn_bias_type == "no_bias":
            return (None, None, None, None, None, dq, dkv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, return (dq, dkv, dbias)
        return (None, None, None, None, None, dq, dkv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)

1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors.

    The forward pass runs ``fused_attn_fwd``. The backward pass runs either
    ``fused_attn_bwd`` or, when ``use_FAv2_bwd`` is set, flash-attn's
    ``flash_attn_cuda_bwd`` fed with the softmax statistics and RNG state
    returned by the forward pass (``aux_ctx_tensors``).
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        out, aux_ctx_tensors = fused_attn_fwd(
            is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
            q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv)
        # Kernel configuration replayed verbatim by the backward pass.
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
        if ctx.use_FAv2_bwd:
            # flash-attn's backward takes no bias, so it can neither apply it nor
            # produce its gradient. Fail with a clear message instead of the
            # UnboundLocalError on `rest` that the bias branch below would hit.
            if ctx.attn_bias_type != "no_bias":
                raise NotImplementedError(
                    "NVTE_FUSED_ATTN_USE_FAv2_BWD=1 does not support attention bias "
                    f"(attn_bias_type = {ctx.attn_bias_type}).")
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)

            def maybe_contiguous(x):
                # flash-attn requires a unit stride on the last dimension.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x)
                                   for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dk, dv,
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                ctx.attn_mask_type == "causal", None, rng_state
            )
            # NOTE(review): presumably a no-op unless the head dim was padded;
            # mirrors flash-attn's own wrapper — confirm.
            dq = dq[..., :d_out.shape[-1]]
            dk = dk[..., :d_out.shape[-1]]
            dv = dv[..., :d_out.shape[-1]]
        else:
            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # One gradient slot per forward input. Only q/k/v (indices 5-7) and,
        # with a bias, attn_bias (index 9) receive gradients.
        if ctx.attn_bias_type == "no_bias":
            return (None, None, None, None, None, dq, dk, dv, None, None, None,
                    None, None, None, None, None, None,
                    None, None, None, None, None, None)
        # else, return (dq, dk, dv, dbias)
        return (None, None, None, None, None, dq, dk, dv, None, rest[0], None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)

1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self                           |
    | qkv_layout    |                         |                                |
    |  - qkv        | qkv_interleaved         | qkv_interleaved                |
    |  - (q,kv)     | kv_interleaved          |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/no_mask          | causal                         |
    | bias_type     | no_bias/post_scale_bias | no_bias                        |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512                   | any                            |
    | head_dim      | 64                      | 64,128                         |
    | output dtype  | fp16/bf16               | fp16/bf16                      |

    With backend 2, setting ``NVTE_FUSED_ATTN_USE_FAv2_BWD=1`` switches the
    backward pass to flash-attn v2. This only takes effect when flash-attn v2
    is installed and the device compute capability is 9.0.
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
    ) -> None:
        super().__init__()

        self.norm_factor = norm_factor
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # Opt-in to flash-attn v2's backward (only when FA2 is installed and on sm90).
        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
                        and _flash_attn_2_available
                        and get_device_compute_capability() == 9.0)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        attn_mask_type: str = "causal",
        fused_attention_backend:
            tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
    ) -> torch.Tensor:
        """fused attention fprop

        Parameters
        ----------
        query_layer, key_layer, value_layer : torch.Tensor
            FP16/BF16 CUDA tensors laid out as described by ``qkv_layout``.
        qkv_layout : str, default = "sbh3d"
            must be one of ``QKVLayouts``. Its first ``_``-separated component,
            with digits removed, gives the dimension format: ``sbhd``, ``bshd``
            or ``thd``.
        cu_seqlens_q, cu_seqlens_kv : Optional[torch.Tensor]
            cumulative sequence lengths. They are required for ``thd``. For
            ``sbhd``/``bshd``, when omitted, they are built as int32 tensors that
            treat every sequence as full-length.
        attn_mask_type : str, default = "causal"
            mask type forwarded to the kernel.
        fused_attention_backend : tex.NVTE_Fused_Attn_Backend
            kernel backend; must not be ``NVTE_No_Backend``.
        core_attention_bias_type, core_attention_bias
            bias type and bias tensor forwarded to the kernel.
        fast_zero_fill : bool, default = True
            forwarded to the kernel.

        Returns
        -------
        torch.Tensor
            attention output with its last two dimensions merged into one.
        """
        assert (fused_attention_backend
            != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
            ), 'No fused attention backend supports this input combination!'
        qkv_tensors = (query_layer, key_layer, value_layer)
        assert all(
            t.dtype in (torch.float16, torch.bfloat16) for t in qkv_tensors
            ), 'FusedAttention only supports FP16 and BF16 data types.'
        assert all(
            t.is_cuda for t in qkv_tensors
            ), 'FusedAttention only supports CUDA tensors.'
        assert (
            qkv_layout in QKVLayouts
            ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        # e.g. "sbh3d" -> "sbhd", "bshd_bs2hd" -> "bshd"
        qkv_format = ''.join(c for c in qkv_layout.split('_')[0] if c.isalpha())
        if qkv_format in ('sbhd', 'bshd'):
            if qkv_format == 'sbhd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1], query_layer.shape[0], key_layer.shape[0])
            else:
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
            # Without explicit offsets, every sequence spans the full padded length.
            if cu_seqlens_q is None:
                cu_seqlens_q = torch.arange(
                        0,
                        (batch_size + 1) * max_seqlen_q,
                        step=max_seqlen_q,
                        dtype=torch.int32,
                        device=query_layer.device)
            if cu_seqlens_kv is None:
                cu_seqlens_kv = torch.arange(
                        0,
                        (batch_size + 1) * max_seqlen_kv,
                        step=max_seqlen_kv,
                        dtype=torch.int32,
                        device=key_layer.device)
        elif qkv_format == 'thd':
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            # Sequences are packed, so the max length comes from the offsets.
            seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
            seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
            max_seqlen_q = seqlens_q.max().item()
            max_seqlen_kv = seqlens_kv.max().item()

        qkv_dtype = TE_DType[query_layer.dtype]

        # The FA2 backward is only wired up for the arbitrary-seqlen backend.
        use_FAv2_bwd = (self.use_FAv2_bwd
                and (fused_attention_backend
                    == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))
        with self.attention_dropout_ctx():
            output = FusedAttnFunc.apply(
                self.training,
                max_seqlen_q, max_seqlen_kv,
                cu_seqlens_q, cu_seqlens_kv,
                query_layer, key_layer, value_layer,
                qkv_dtype,
                core_attention_bias,
                1.0/self.norm_factor,
                self.attention_dropout if self.training else 0.0,
                fast_zero_fill,
                qkv_layout,
                core_attention_bias_type,
                attn_mask_type,
                None, # rng_gen
                fused_attention_backend,
                use_FAv2_bwd,
            )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
1666
1667


1668
1669
1670
1671
1672
1673
1674
class DotProductAttention(torch.nn.Module):
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

1675
1676
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.
1677
1678
1679

    .. warning::

1680
1681
1682
1683
        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
        deterministic behavior at the cost of performance, use FlashAttention version < `2.0.0`
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
1684
1685
1686
1687
1688
1689
1690

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels : int
                number of key-value channels.
1691
1692
1693
1694
1695
1696
1697
1698
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1699
1700
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
    attn_mask_type: str, default = `causal`
                   type of attention mask passed into softmax operation, options are "`causal`",
                   "`padding`", "`arbitrary`", "`no_mask`". For the "`causal`" mask,
                   TransformerEngine calculates and applies an upper triangular mask to
                   the softmax input. An "`arbitrary`" mask is an arbitrary user defined mask
                   broadcastable to the shape of softmax input. The "`padding`" mask is used
                   for providing locations of padded tokens in the batch, which should be of
                   the shape [batch_size, 1, 1, seq_len]. No mask is applied for the "`no_mask`"
                   option. For the `"arbitrary"` and `"padding"` mask types, the argument
                   :attr:`attention_mask` must be passed into `forward` call. The "`causal`"
                   mask can also be applied in conjunction with "`padding`" mask by passing
                   in multiple mask type as a comma separated string, for example,
                   `attn_mask_type="causal,padding"`.
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
1716
1717
1718
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
               `h` the number of heads, `d` head size, and `t` the total number of sequences
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used for when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used for when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
               For that, please use `_get_qkv_layout` to gain the layout information.
1729
    attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
1730
1731
1732
1733
1734
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
1735
1736
1737
1738
1739
1740
1741
1742
1743

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
1744
1745
1746
1747
1748
1749
1750
1751
1752
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
1753
1754
1755
1756
1757
1758
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()

        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        # An explicit tp_group's world size takes precedence over `tp_size`.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads

        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

        self.hidden_size_per_attention_head = kv_channels
        # Without GQA groups, every head has its own K/V (plain MHA).
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        # Use the resolved TP world size (not the raw `tp_size` argument) so the
        # per-partition count is also correct when only `tp_group` is supplied.
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (num_attention_heads % self.num_gqa_groups == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"

        # Dropout runs under the RNG tracker's fork unless sequence parallelism
        # is enabled or no tracker was provided.
        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            attention_dropout_ctx = get_rng_state_tracker().fork

        norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.device_compute_capability = get_device_compute_capability()
        # NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 requests deterministic execution.
        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

        self.use_flash_attention = (
            int(os.getenv("NVTE_FLASH_ATTN", "1"))
            and self.device_compute_capability >= 8.0
        )
        if _flash_attn_2_available and self.deterministic:
            self.use_flash_attention = False
            warnings.warn(
                "Disabling usage of FlashAttention since version 2 does not support deterministic "
                "execution. In order to use FA with deterministic behavior, please install "
                "FlashAttention version 1."
            )

        self.use_fused_attention = (
            int(os.getenv("NVTE_FUSED_ATTN", "1"))
            and self.device_compute_capability >= 8.0
        )

        assert (
            attention_type in AttnTypes
        ), f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        if self.use_flash_attention:
            self.flash_attention = FlashAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        if self.use_fused_attention:
            self.fused_attention = FusedAttention(
                norm_factor, **attn_kwargs,
                attention_type=attention_type)
        self.unfused_attention = UnfusedDotProductAttention(
            norm_factor, **attn_kwargs, layer_number=layer_number)

    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Any,
        **forward_kwargs: Any,
    ) -> torch.Tensor:
        """Forward method with activation checkpointing.

        Runs ``attention_func(*forward_args, **forward_kwargs)`` through
        ``checkpoint``, passing this module's RNG state tracker and TP group.

        Parameters
        ----------
        attention_func : Callable
            attention backend to execute.
        *forward_args, **forward_kwargs
            forwarded unchanged to ``attention_func``.

        Returns
        -------
        torch.Tensor
            whatever ``attention_func`` returns.
        """

        def custom_forward(*input_args, **input_kwargs):
            return attention_func(*input_args, **input_kwargs)

        # NOTE(review): the positional `False` is a `checkpoint` flag defined
        # elsewhere (presumably distribute_saved_activations) — confirm.
        hidden_states = checkpoint(
            custom_forward,
            False,
            self.get_rng_state_tracker,
            self.tp_group,
            *forward_args,
            **forward_kwargs,
        )

        return hidden_states

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
1876
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
1877
1878
1879
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
1880
        attn_mask_type: Optional[str] = None,
1881
        checkpoint_core_attention: bool = False,
1882
1883
1884
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
1885
1886
1887
1888
1889
1890
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

1891
1892
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
1893
1894
1895
1896
1897
1898
1899
1900
1901

        .. note::

            Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
            must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
            :attr:`num_attention_heads`, :attr:`kv_channels`). Output of shape
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
            * :attr:`kv_channels`) is returned.

1902
1903
        .. note::

1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
1922
1923
1924
1925
1926
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.
1927

1928
1929
1930
1931
1932
1933
1934
1935
        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
        attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                        Boolean tensor used to mask out softmax input when not using flash-attn.
                        Can be a tuple of 2 masks for cross attention with padding masks.
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
        attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `None`
                       type of attention mask passed into softmax operation. If `None`,
                       the mask type given at initialization is used.
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T
        fast_zero_fill: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
        """

1962
1963
1964
1965
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'DotProductAttention only supports CUDA tensors.'

1966
1967
1968
        assert (key_layer.shape == value_layer.shape
            ), "Keys and values must have the same shape!"

1969
        if attn_mask_type is None:
1970
            attn_mask_type = self.attn_mask_type
1971
1972
        if qkv_format is None:
            qkv_format = self.qkv_format
1973
        attn_mask_type, causal_mask = _unpack_attn_mask_type(attn_mask_type)
1974

1975
        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
            and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
        assert (qkv_format in ['sbhd', 'bshd', 'thd']
            ), "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

        if qkv_format == 'thd':
            assert (all(len(x.shape) == 3 for x in (query_layer, key_layer, value_layer))
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
                and len(cu_seqlens_q.shape) == 1
                and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
            assert (cu_seqlens_q.dtype == torch.int32
                and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
            seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
            seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
            max_seqlen_q = seqlens_q.max().item()
            max_seqlen_kv = seqlens_kv.max().item()

        if qkv_format in ['sbhd', 'bshd']:
            assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
            if qkv_format == 'sbhd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
            if qkv_format == 'bshd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
            if cu_seqlens_q is not None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                assert (all(seqlens_q <= max_seqlen_q)
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                    the sequence dimention in 'query_layer'!"""
            if cu_seqlens_kv is not None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                assert (all(seqlens_kv <= max_seqlen_kv)
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                    the sequence dimention in 'key_layer' and 'value_layer'!"""

        qkv_layout = _get_qkv_layout(query_layer, key_layer, value_layer,
            qkv_format = qkv_format)
2018

2019
2020
        # The priority for attention backends (subject to availability and clearing the filters)
        # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
2021
        use_flash_attention = self.use_flash_attention
2022
2023
        use_fused_attention = self.use_fused_attention

2024
2025
2026
2027
        # The following section filters out some backends based on
        # certain asserts before executing the forward pass.

        # Filter: Input type.
2028
2029
2030
2031
2032
        if (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16]
        ):
            use_flash_attention = False
2033
            use_fused_attention = False
2034

2035
        # Filter: Device and dimensions.
2036
2037
2038
2039
2040
2041
        if key_layer.shape[-1] > 64:
            if self.device_compute_capability in (8.6, 8.7):
                use_flash_attention = False
            elif not _flash_attn_2_available and self.device_compute_capability == 8.9:
                use_flash_attention = False

2042
2043
2044
        if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
            use_flash_attention = False

2045
2046
2047
        if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
            use_flash_attention = False

2048
        # Filter: ONNX export.
2049
2050
        if is_in_onnx_export_mode():
            use_flash_attention = False
2051
2052
            use_fused_attention = False

2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
        # Filter: Attention mask type.
        #    attn_mask_type(s)   |     supported backends
        # ------------------------------------------------
        #   causal               |     All
        #   padding              |     UnfusedDotProductAttention, FlashAttention
        #   arbitrary            |     UnfusedDotProductAttention
        #   no_mask              |     All
        #   causal + padding     |     FlashAttention
        #
        if attn_mask_type == "arbitrary":
            use_flash_attention = False
            use_fused_attention = False
        elif attn_mask_type == "padding" and causal_mask:
            assert use_flash_attention, "No attention backend available for causal + padding masks."
        elif attn_mask_type == "padding":
            use_fused_attention = False

2070
2071
2072
2073
2074
2075
        if use_fused_attention:
            fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype],
                TE_DType[key_layer.dtype],
                QKVLayout[qkv_layout],
                AttnBiasType[core_attention_bias_type],
2076
                AttnMaskType[attn_mask_type],
2077
                self.attention_dropout,
2078
                max_seqlen_q, max_seqlen_kv,
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
                query_layer.shape[-1])
            # DPA does not support FP8; for FP8, use cpp_extensions modules directly
            is_backend_avail = (fused_attention_backend in
                [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
            use_fused_attention = (use_fused_attention
                                  and is_backend_avail
                                  and self.num_gqa_groups == self.num_attention_heads)
            if (self.deterministic
                and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
                use_fused_attention = False
                warnings.warn(
                    "Disabling usage of FusedAttention since the FusedAttention"
                    "backend does not support deterministic exection."
                )
2093
2094
2095
2096
2097
2098

        if use_flash_attention:
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(self.flash_attention,
                                                            query_layer,
                                                            key_layer,
2099
                                                            value_layer,
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
                                                            attention_mask=attention_mask,
                                                            qkv_layout=qkv_layout,
                                                            cu_seqlens_q=cu_seqlens_q,
                                                            cu_seqlens_kv=cu_seqlens_kv,
                                                            attn_mask_type=attn_mask_type,
                                                            cp_group=self.cp_group,
                                                            cp_global_ranks=self.cp_global_ranks,
                                                            cp_stream=self.cp_stream)
            return self.flash_attention(query_layer,
                                        key_layer,
                                        value_layer,
                                        attention_mask=attention_mask,
                                        qkv_layout=qkv_layout,
                                        cu_seqlens_q=cu_seqlens_q,
                                        cu_seqlens_kv=cu_seqlens_kv,
                                        attn_mask_type=attn_mask_type,
                                        cp_group=self.cp_group,
                                        cp_global_ranks=self.cp_global_ranks,
                                        cp_stream=self.cp_stream)
2119
2120
2121
2122

        assert (
            self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
        ), "Context parallelism is only implemented with Flash Attention!"
2123

2124
2125
2126
        if use_fused_attention:
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(self.fused_attention,
2127
2128
2129
                              query_layer,
                              key_layer,
                              value_layer,
2130
2131
2132
2133
2134
2135
2136
2137
                              qkv_layout = qkv_layout,
                              cu_seqlens_q = cu_seqlens_q,
                              cu_seqlens_kv = cu_seqlens_kv,
                              attn_mask_type = attn_mask_type,
                              fused_attention_backend = fused_attention_backend,
                              core_attention_bias_type = core_attention_bias_type,
                              core_attention_bias = core_attention_bias,
                              fast_zero_fill = fast_zero_fill)
2138
            return self.fused_attention(query_layer, key_layer, value_layer,
2139
2140
2141
2142
2143
2144
2145
2146
                              qkv_layout = qkv_layout,
                              cu_seqlens_q = cu_seqlens_q,
                              cu_seqlens_kv = cu_seqlens_kv,
                              attn_mask_type = attn_mask_type,
                              fused_attention_backend = fused_attention_backend,
                              core_attention_bias_type = core_attention_bias_type,
                              core_attention_bias = core_attention_bias,
                              fast_zero_fill = fast_zero_fill)
2147

2148
2149
2150
2151
2152
2153
        if checkpoint_core_attention:
            return self._checkpointed_attention_forward(
                self.unfused_attention,
                query_layer,
                key_layer,
                value_layer,
2154
2155
2156
2157
2158
2159
2160
                qkv_layout = qkv_layout,
                cu_seqlens_q = cu_seqlens_q,
                cu_seqlens_kv = cu_seqlens_kv,
                attn_mask_type = attn_mask_type,
                attention_mask = attention_mask,
                core_attention_bias_type = core_attention_bias_type,
                core_attention_bias = core_attention_bias)
2161
2162
2163
        return self.unfused_attention(query_layer,
                key_layer,
                value_layer,
2164
2165
2166
2167
2168
2169
2170
                qkv_layout = qkv_layout,
                cu_seqlens_q = cu_seqlens_q,
                cu_seqlens_kv = cu_seqlens_kv,
                attn_mask_type = attn_mask_type,
                attention_mask = attention_mask,
                core_attention_bias_type = core_attention_bias_type,
                core_attention_bias = core_attention_bias)


class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing the QKV projection weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing the PROJ weights in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied to the input.
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, the QKV projection is used as Column Parallel
                      whereas PROJ is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, this module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_split_rs: bool = False,
        ub_split_ag: bool = False,
        ub_atomic_gemm_rs: bool = False,
        ub_atomic_gemm_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
    ) -> None:
        """Build the QKV (or query + key/value) projections, core attention and output PROJ.

        All arguments are documented on the class docstring.
        """
        super().__init__()

        self.attn_mask_type = attn_mask_type
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        # Per-head channel count; falls back to an even split of hidden_size across heads.
        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Interleaving only makes sense for a single fused QKV weight.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # An explicit TP group takes precedence over the `tp_size` hint.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.hidden_size_per_attention_head = kv_channels
        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
        # `None` means one KV group per query head, i.e. plain multi-head attention.
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        assert (num_attention_heads % self.num_gqa_groups == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (self.num_gqa_groups % tp_size == 0
                ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        # Width of the K (and of the V) projection: hidden_size scaled by the KV/Q head ratio.
        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads)

        # Arguments shared by every projection GEMM in this module.
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # One projection producing Q, K and V; split into named parameters unless fused.
            parameters_split = {"query_": hidden_size,
                                "key_": self.hidden_size_kv,
                                "value_": self.hidden_size_kv} if not fuse_qkv_params else None
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_split_ag=ub_split_ag,
                    normalization=normalization,
                    ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # Query comes from `hidden_states`; K/V get their own projection below.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query_",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_split_ag=ub_split_ag,
                    normalization=normalization,
                    ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key_", "value_") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            kv_channels,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
        )

        # Linear (output projection, row-parallel when `set_parallel_mode` is set).
        self.proj = Linear(
            hidden_size,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_split_rs=ub_split_rs,
            ub_split_ag=ub_split_ag,
            ub_atomic_gemm_rs=ub_atomic_gemm_rs,
            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
            **common_gemm_kwargs,
        )


    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Allocate an uninitialized 4D buffer on the current CUDA device.

        The buffer has shape ``[inference_max_sequence_len, batch_size,
        num_gqa_groups_per_partition, hidden_size_per_attention_head]``. Its contents
        are uninitialized (``torch.empty``). NOTE(review): presumably used as a
        per-layer key or value cache during inference; the callers are not visible here.

        Parameters
        ----------
        inference_max_sequence_len : int
                                    size of the first (sequence) dimension.
        batch_size : int
                    size of the second (batch) dimension.
        dtype : torch.dtype
               element type of the buffer.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """Set TP group

        Parameters
        ----------
        tp_group : Union[dist_group_type, None]
            Tensor-parallel process group; stored on this module as ``self.tp_group``.
        """
        self.tp_group = tp_group

    def set_context_parallel_running(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """Set CP group and CP dual-stream running

        The three values are stored on ``self.core_attention`` (the
        ``DotProductAttention`` submodule built in ``__init__``) as
        ``cp_group``, ``cp_global_ranks`` and ``cp_stream``.

        Parameters
        ----------
        cp_group : Union[dist_group_type, None]
            Context-parallel process group.
        cp_global_ranks : List[int]
            Global ranks participating in context parallelism (presumably the
            members of ``cp_group`` -- confirm with the caller).
        cp_stream : torch.cuda.Stream
            Additional CUDA stream used for the CP dual-stream execution.
        """
        self.core_attention.cp_group = cp_group
        self.core_attention.cp_global_ranks = cp_global_ranks
        self.core_attention.cp_stream = cp_stream

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[str] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
            is set to `"causal"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
        attention_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
                         default = `None`
             Boolean tensor used to mask out self-attention softmax input. A tuple of
             two boolean tensors is also accepted (presumably separate masks for the
             query and key/value sequences -- confirm against `DotProductAttention`).
             With `attn_mask_type="padding"` every provided mask must be boolean.
        attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `None`
                       type of attention mask passed into softmax operation. When `None`,
                       the mask type configured on this module (`self.attn_mask_type`)
                       is used.
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        inference_params: Optional[Any], default = `None`
                         Incremental-decoding state. When given and this layer has a
                         `layer_number`, keys/values are cached in
                         `inference_params.key_value_memory_dict[layer_number]` (allocated
                         with `max_sequence_len` x `max_batch_size` on first use) and the new
                         tokens are written at `sequence_len_offset` / `batch_size_offset`.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied. A single
                       tensor is used for both query and key.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.

        Returns
        -------
        The attention output ([sq, b, h]) alone if it is the only output; otherwise a
        tuple `(attention_output[, attention_bias][, layernorm_output])`, where
        `attention_bias` is included when `return_bias` is set and `layernorm_output`
        when both `input_layernorm` and `return_layernorm_output` are set.
        """
        # hidden_states: [sq, b, h]

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type

        if attn_mask_type == "padding" and attention_mask is not None:
            # The mask may be a single tensor or a tuple of tensors; validate each
            # one instead of reading `.dtype` off a tuple (which would raise
            # AttributeError).
            masks = (
                attention_mask
                if isinstance(attention_mask, (tuple, list))
                else (attention_mask,)
            )
            for mask in masks:
                assert (
                    mask.dtype == torch.bool
                ), "Attention mask must be a boolean tensor"

        assert (core_attention_bias_type in AttnBiasTypes
                ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================

        is_first_step = False
        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
                )
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
                # Cache was just created: this call processes the whole prompt.
                is_first_step = True
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            num_queries_per_key_value = (self.num_attention_heads_per_partition //
                                         self.num_gqa_groups_per_partition)
            if self.qkv_weight_interleaved:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head
                )
                # split along third last dimension
                split_dim = -3

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, ng, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
                )
            else:
                # ONNX export uses plain torch.split instead of the custom autograd op.
                query_layer, key_layer, value_layer = torch.split(
                    mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim,
                )

            # query: -> [sq, b, np, hn]
            # key, value: -> [sq, b, ng, hn]
            query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1,
                                                             self.hidden_size_per_attention_head)
                                                   for x in (query_layer, key_layer, value_layer))

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
            )

            if self.qkv_weight_interleaved:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, ng, 2 * hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    self.num_gqa_groups_per_partition,
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
                # [sk, b, (ng * 2 * hn)] --> [sk, b, 2 * ng, hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    2 * self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2,
                )
            else:
                # ONNX export uses plain torch.split instead of the custom autograd op.
                key_layer, value_layer = torch.split(
                    mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
                )

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # A single rotary embedding tensor is reused for both query and key.
        if rotary_pos_emb is not None:
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params and self.layer_number is not None:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...
            ] = key_layer
            inference_value_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...
            ] = value_layer
            # Attend over everything cached so far, not only the new tokens.
            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...
            ]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
            # value_layer may still be a strided view from the split above.
            # NOTE(review): presumably made contiguous to match the layout of the
            # rotary-embedded query/key -- confirm the backend requirement.
            value_layer = value_layer.contiguous()

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            qkv_format='sbhd',
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            fast_zero_fill=fast_zero_fill,
        )

        # =================
        # Output. [sq, b, h]
        # =================

        projection_output = self.proj(
            context_layer, is_first_microbatch=is_first_microbatch
        )

        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
        if self.input_layernorm and self.return_layernorm_output:
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]