attention.py 163 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Attention."""
6
import collections
7
import os
8
import warnings
9
10
11
import math
from importlib.metadata import version
from contextlib import nullcontext
12
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
13
from pkg_resources import packaging
cyanguwa's avatar
cyanguwa committed
14
import numpy as np
15
16

import torch
17
import torch.nn.functional as F
18
19

import transformer_engine_extensions as tex
20
21
22
23
24
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
25
26
    fused_attn_fwd,
    fused_attn_bwd,
27
28
29
30
31
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
32
33
34
35
36
37
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    get_device_compute_capability,
38
    get_default_init_method,
39
40
41
42
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
43
    AttnBiasTypes,
44
    QKVLayouts,
45
    dist_group_type,
46
    TE_DType,
47
48
49
50
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
51
    get_distributed_rank,
52
53
54
    checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
55
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
56
57

# Installed flash-attn version and the minimum version this module relies on.
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")
# Version gates; used below to decide which optional flash-attn kwargs may be passed
# (e.g. window_size needs 2.3+, alibi_slopes 2.4+, deterministic 2.4.1+).
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1")

if _flash_attn_version >= _flash_attn_version_required:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module

# Module-level slots for sequence-length / index tensors; initialised to None here.
_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
# Debug switch read from the NVTE_DEBUG environment variable ("0" when unset).
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))


__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]


class InferenceParams: # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # offsets into the cache; both start at 0 and are not advanced in this class
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # layer_number -> (key cache, value cache); dim 1 of each cache is the batch
        # dimension (see swap_key_value_dict)
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            if no layer has stored a KV cache yet.
        """
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number, (key_memory, value_memory) in self.key_value_memory_dict.items():
            # make sure batch size is the same
            assert len(batch_indices) == key_memory.shape[1], (
                f"batch_indices has {len(batch_indices)} entries but the KV cache of "
                f"layer {layer_number} has batch size {key_memory.shape[1]}"
            )
            # Re-assigning existing keys does not change the dict size, so this is
            # safe while iterating over items().
            self.key_value_memory_dict[layer_number] = (
                key_memory[:, batch_indices],
                value_memory[:, batch_indices],
            )
122

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
    """
    Generate ALiBi bias in the shape of [1, num_heads, max_seqlen_q, max_seqlen_kv].

    Entry [0, h, i, j] equals ``slope[h] * (j - i)``. With ``n`` the largest power of
    two not exceeding ``num_heads``, the first ``n`` slopes are ``(2 ** (-8 / n)) ** k``
    for ``k = 1..n``; any remaining heads use ``(2 ** (-4 / n)) ** k`` for odd
    ``k = 1, 3, 5, ...``.

    Parameters
    ----------
    num_heads : int
        number of attention heads (must be >= 1).
    max_seqlen_q : int
        number of query positions.
    max_seqlen_kv : int
        number of key/value positions.
    device : str or torch.device, default = "cuda"
        device of the returned float32 tensor.
    """
    n = 2 ** math.floor(math.log2(num_heads))
    m_0 = 2.0 ** (-8.0 / n)
    slopes = torch.pow(m_0, torch.arange(1, 1 + n))

    if n < num_heads:
        # Non-power-of-two head counts: take the odd powers of the base used by a
        # 2n-head model for the extra heads.
        m_hat_0 = 2.0 ** (-4.0 / n)
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
        slopes = torch.cat([slopes, m_hat])

    # relative distance (j - i) between key position j and query position i
    rel_pos = (torch.arange(max_seqlen_kv).view(1, -1)
               - torch.arange(max_seqlen_q).view(-1, 1)).to(slopes.dtype)
    # broadcast [1, h, 1, 1] * [1, 1, sq, skv] -> [1, h, sq, skv]
    bias = slopes.view(1, num_heads, 1, 1) * rel_pos.view(1, 1, max_seqlen_q, max_seqlen_kv)

    return bias.to(dtype=torch.float32, device=device)

def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    Nonzero mask entries are counted as valid tokens. The result is placed on the
    same device as ``mask``.
    """
    # [batch_size, 1, 1, max_seqlen] -> [batch_size, max_seqlen]
    mask = mask.squeeze(1).squeeze(1)
    seqlens = mask.sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Create the leading 0 on mask's device: torch.cat needs a common device, so a
    # hard-coded "cuda" fails for masks on the CPU or on a non-current GPU.
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    return torch.cat((zero, cu_seqlens))

169

170
171
172
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flattened indices (``b * max_seqlen + s``) of the valid tokens.

    Nonzero mask entries count as valid tokens. The index tensor is right-padded with
    the out-of-range value ``batch_size * max_seqlen`` so its length does not depend
    on the mask contents. Both outputs live on ``mask.device``.
    """
    # [batch_size, 1, 1, max_seqlen] -> [batch_size, max_seqlen]
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    seqlens = mask.sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # leading 0 on mask's device (a hard-coded "cuda" breaks torch.cat elsewhere)
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # [bs, seqlen] -> [bs*seqlen] -> nonzero positions [num_valid, 1] -> [num_valid, 1, 1]
    indices = mask.reshape(-1).nonzero().unsqueeze(-1)

    # pad dim 0 up to bs*seqlen entries with the out-of-range index bs*seqlen
    pad_amount = bs * seqlen - indices.shape[0]
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * seqlen))

    return cu_seqlens, indices


197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def get_indices(
    max_seqlen: int,
    cu_seqlens: torch.Tensor,
    device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for
    the valid tokens in a batch.

    Token ``s`` of sample ``b`` maps to ``b * max_seqlen + s``. The tensor is
    right-padded with the out-of-range value ``batch_size * max_seqlen``.

    Parameters
    ----------
    max_seqlen : int
        padded sequence length of every sample.
    cu_seqlens : torch.Tensor
        cumulative sequence lengths; expected to be non-decreasing.
    device : str or torch.device, default = "cuda"
        device of the returned tensor.
    """
    # Everything is computed in int64 tensor ops: building the indices as a Python
    # list and passing it to torch.Tensor() went through float32, which rounds
    # indices above 2**24, and looped per token in Python.
    cu_seqlens = cu_seqlens.to(device=device, dtype=torch.int64)
    bs = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

    # sample id of every valid token, e.g. seqlens [2, 1] -> [0, 0, 1]
    batch_ids = torch.repeat_interleave(torch.arange(bs, device=device), seqlens)
    # position of each token within its own sample, e.g. [0, 1, 0]
    starts = torch.cumsum(seqlens, dim=0) - seqlens
    positions = torch.arange(batch_ids.numel(), device=device) - starts[batch_ids]
    indices = (batch_ids * max_seqlen + positions).view(-1, 1, 1)

    num_nonzeros = indices.shape[0]
    pad_amount = bs * max_seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * max_seqlen))

    return indices


217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Gather the rows (dim 0) of a 3-D `tensor` selected by `indices`.

    A zero row is appended at position ``tensor.shape[0]``, so any index equal to
    that value yields an all-zero row in the output. `indices` is presumably the
    [N, 1, 1] int64 index tensor produced alongside cu_seqlens; it is broadcast over
    dims 1 and 2 with `repeat` before the gather.
    """
    zero_row = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    extended = torch.cat((tensor, zero_row), dim=0)

    gather_index = indices.repeat(1, extended.shape[1], extended.shape[2])
    return torch.gather(extended, 0, gather_index)


@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply `pack_tensor` with the same `indices` to two tensors and return
    the packed results in input order.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)


@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply `pack_tensor` with the same `indices` to three tensors and return
    the packed results in input order.
    """
    return (
        pack_tensor(indices, t1),
        pack_tensor(indices, t2),
        pack_tensor(indices, t3),
    )


@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`: scatter the rows of `tensor` back to the positions
    given by `indices` in a zero-initialised tensor with `dim0` rows.

    An extra trailing row absorbs writes from indices equal to `dim0` (the padding
    value) and is dropped before returning.
    """
    scatter_index = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    buffer = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    buffer.scatter_(0, scatter_index, tensor)
    return buffer[:-1, :, :]


@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`: `unpack_tensor` applied to each input with the
    same `indices` and `dim0`, results returned in input order.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)


@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`: `unpack_tensor` applied to each input with the
    same `indices` and `dim0`, results returned in input order.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )


class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward gathers the valid rows of 1 to 3 tensors with a shared `indices`
    tensor (via `pack_tensor` and friends); backward scatters the incoming
    gradients back to the original row count (via `unpack_tensor` and friends).
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        # remembered for backward: gather indices and the unpacked row count
        ctx.indices = indices
        ctx.dim0 = tensors[0].shape[0]
        if len(tensors) == 1:
            return pack_tensor(indices, *tensors)
        packer = pack_2_tensors if len(tensors) == 2 else pack_3_tensors
        return packer(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        # `indices` receives no gradient, hence the leading None
        num_grads = len(grad_outputs)
        if num_grads == 1:
            return None, unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs)
        unpacker = unpack_2_tensors if num_grads == 2 else unpack_3_tensors
        return (None, *unpacker(ctx.indices, ctx.dim0, *grad_outputs))


class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    Forward scatters a packed tensor back to `dim0` rows with `unpack_tensor`;
    backward re-packs the incoming gradient with `pack_tensor`.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        ctx.indices = indices
        unpacked = unpack_tensor(indices, dim0, tensor)
        return unpacked

    @staticmethod
    def backward(ctx, grad_output):
        # `indices` and `dim0` receive no gradient
        packed_grad = pack_tensor(ctx.indices, grad_output)
        return None, None, packed_grad


360
361
362
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                               recv_tensor, recv_src,
                               cp_group, batch_p2p_comm):
    """Point-to-point communications of KV and dKV in Attention with context parallelism

    Posts one non-blocking send of `send_tensor` to `send_dst` and one non-blocking
    receive into `recv_tensor` from `recv_src` within `cp_group`. When
    `batch_p2p_comm` is truthy both are issued together through
    `torch.distributed.batch_isend_irecv`; otherwise they are issued as separate
    `isend` / `irecv` calls. Returns the list of request handles to wait on.
    """
    send_spec = (torch.distributed.isend, send_tensor, send_dst)
    recv_spec = (torch.distributed.irecv, recv_tensor, recv_src)
    # Even ranks post the send first, odd ranks the receive first.
    # NOTE(review): presumably this pairs up neighbouring ranks' operations in the
    # same order to avoid deadlock -- confirm against the backend's requirements.
    if rank % 2 == 0:
        ordered_specs = (send_spec, recv_spec)
    else:
        ordered_specs = (recv_spec, send_spec)

    if not batch_p2p_comm:
        return [op(tensor, peer, cp_group) for op, tensor, peer in ordered_specs]

    p2p_ops = [
        torch.distributed.P2POp(op, tensor, peer, cp_group)
        for op, tensor, peer in ordered_specs
    ]
    return torch.distributed.batch_isend_irecv(p2p_ops)


406
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism

    Adds ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` to ``out`` in
    place, i.e. rescales one step's partial output from its own softmax normalizer
    to the merged one before accumulating it.
    """
    # In AttnFuncWithCP the LSE tensors are laid out [b, np, s] and the outputs
    # [b, s, np, hn]; swap dims 1 and 2 and add a trailing unit dim so the factor
    # broadcasts over the head dimension.
    rescale = torch.exp(softmax_lse_per_step - softmax_lse).transpose(1, 2).unsqueeze(-1)
    out.add_(out_per_step * rescale)


415
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """Merge softmax stats of each step in Attention with context parallelism

    Updates ``softmax_lse`` in place to
    ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``.

    The previous exp_/add_/log_ round trip materialised ``exp(softmax_lse)``, which
    overflows to inf once an LSE value exceeds ~709.78 even in float64, poisoning the
    merged statistics. ``torch.logaddexp`` computes the same quantity without that
    intermediate. ``softmax_lse_per_step`` is promoted to float64 first, as before.
    """
    merged = torch.logaddexp(softmax_lse, softmax_lse_per_step.to(torch.double))
    # copy_ keeps the update in place, so views passed by the caller (e.g. one half
    # of the causal LSE) are updated as well
    softmax_lse.copy_(merged)


423
class AttnFuncWithCP(torch.autograd.Function):
    """
    Attention implementation with context parallelism.
    Split attention compute into multiple steps, and overlap current-step
    compute with next-step communication.
    """

    @staticmethod
    def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
                deterministic, use_fused_attention):
        """
        Ring-style forward pass of attention with context parallelism.

        Each rank keeps its local Q shard and passes its KV shard around the ring in
        cp_size steps. At every step KV is sent to ``cp_global_ranks[(rank + 1) % cp_size]``
        and received from ``cp_global_ranks[(rank - 1) % cp_size]``, so the KV used at
        step i is the one that started on rank ``(rank - i) % cp_size``. Per-step partial
        outputs and softmax log-sum-exp (LSE) values are merged into the final output.

        Parameters
        ----------
        is_training :
            forwarded to ``fused_attn_fwd``.
        q, k, v : torch.Tensor
            local shards laid out as [b, s, np, hn]; the head dim must be a multiple of 8.
        cu_seqlens_q, cu_seqlens_k : torch.Tensor
            cumulative sequence lengths of the local shards; halved (``// 2``) for steps
            that only use half of a shard.
        max_seqlen_q, max_seqlen_k : int
            maximum sequence lengths of the local shards.
        dropout_p : float
            dropout probability forwarded to the attention kernels.
        cp_group :
            context-parallel process group used for the P2P exchange.
        cp_global_ranks :
            global ranks of ``cp_group``, indexed by group rank.
        cp_stream : torch.cuda.Stream
            second stream; steps alternate between it and the current stream.
        softmax_scale : float or None
            defaults to ``hn ** -0.5`` when None.
        attn_mask_type : str
            only "causal" is special-cased; anything else runs without a mask.
        deterministic : bool
            stored on ``ctx`` for the backward pass.
        use_fused_attention : bool
            run each step with ``fused_attn_fwd`` (F16 arbitrary-seqlen backend) instead
            of flash-attn's ``_flash_attn_forward``.

        Returns
        -------
        torch.Tensor
            [b, sq, np, hn] when ``use_fused_attention``, otherwise [b*sq, np, hn].
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        cp_size = get_distributed_world_size(cp_group)
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size]
        recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

        causal = (attn_mask_type == "causal")

        if causal:
            # [b, s, np, hn] -> [b, 2, s//2, np, hn]
            q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
        assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
        fa_optional_forward_kwargs = {}
        if _flash_attn_2_3_plus:
            fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
        if _flash_attn_2_4_plus:
            fa_optional_forward_kwargs["alibi_slopes"] = None

        # Flash Attn inputs
        q_inputs = [None, None]
        kv_inputs = [None, None]
        # Flash Attn outputs
        out_per_step = [None for _ in range(cp_size)]
        softmax_lse_per_step = [None for _ in range(cp_size)]
        rng_states = [None for _ in range(cp_size)]

        # create two streams to resolve wave quantization issue of Flash Attn in each step
        flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
        # synchronize fwd results correction across steps
        fwd_results_correction_done = torch.cuda.Event()

        # p2p_comm_buffers[i] holds the KV (stacked as [2, ...]) consumed at step i
        p2p_comm_buffers = [None for _ in range(cp_size)]
        p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
        send_recv_reqs = [[], []]

        # One extra iteration so the results of the last compute step get merged.
        for i in range(cp_size+1):
            if i < cp_size:
                with torch.cuda.stream(flash_attn_streams[i%2]):
                    # wait until KV is received
                    for req in send_recv_reqs[(i+1)%2]:
                        req.wait()

                    # start forwarding the current KV to the next rank while computing
                    if i < (cp_size-1):
                        p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
                                                                         p2p_comm_buffers[i],
                                                                         send_dst,
                                                                         p2p_comm_buffers[i+1],
                                                                         recv_src,
                                                                         cp_group,
                                                                         batch_p2p_comm)

                    kv_inputs[i%2] = p2p_comm_buffers[i]
                    if causal:
                        # Local shards are split into two halves along the sequence (see
                        # the view above). NOTE(review): this schedule presumably assumes
                        # a load-balanced sharding where rank r holds sequence chunks r
                        # and 2*cp_size-1-r -- confirm with the caller.
                        if i == 0:
                            # local KV: regular causal attention over the whole shard
                            if use_fused_attention:
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(
                                    2, k.shape[0], -1, *k.shape[-2:])
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
                                    qkv_layout="bshd_bshd_bshd", attn_mask_type="causal",
                                )
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                                q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=True, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                        elif i <= rank:
                            # KV from a lower rank: all of Q attends, unmasked, to the
                            # first half of that KV shard only
                            if use_fused_attention:
                                # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                                q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q,
                                    cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
                                    qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
                                )
                            else:
                                # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                                q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
                                # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                        else:
                            # KV from a higher rank: only the second half of Q attends,
                            # unmasked, to the full KV shard
                            if use_fused_attention:
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                                q_inputs[i%2] = q[:, 1, ...].contiguous()
                                # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(
                                    2, k.shape[0], -1, *k.shape[-2:])
                                out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
                                fused_attn_fwd(
                                    is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
                                    qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
                                )
                            else:
                                # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                                q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
                                # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                if _flash_attn_2_3_plus:
                                    fa_optional_forward_kwargs["window_size"] = [-1, -1]
                                _, _, _, _, out_per_step[i], \
                                softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                    cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
                                    dropout_p, softmax_scale, causal=False, return_softmax=False,
                                    **fa_optional_forward_kwargs
                                )
                    else:
                        # no mask: full Q against the full KV shard at every step
                        if use_fused_attention:
                            out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
                            fused_attn_fwd(
                                is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                cu_seqlens_k, q, kv_inputs[i%2][0],
                                kv_inputs[i%2][1], TE_DType[q.dtype],
                                tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                attn_scale=softmax_scale, dropout=dropout_p,
                                qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
                            )
                        else:
                            # [b, sq, np, hn] -> [b*sq, np, hn]
                            q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                            # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
                            kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                            _, _, _, _, out_per_step[i], \
                            softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
                                q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
                                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal=False, return_softmax=False,
                                **fa_optional_forward_kwargs
                            )

            # Merge the LSE of step i-1 on the stream that computed it, while step i
            # (if any) runs on the other stream.
            if i > 0:
                # wait until fwd results correction of last step is done
                if i > 1:
                    flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)

                if use_fused_attention:
                    # [b, np, sq, 1] -> [b, np, sq]
                    softmax_lse_per_step[i-1].squeeze_(-1)

                with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
                    if i == 1:
                        out = torch.empty_like(q).zero_()
                        # LSE is accumulated in float64
                        softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
                        if causal:
                            # [b, np, sq] -> [b, np, 2, sq//2]
                            softmax_lse_ = softmax_lse.view(
                                *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
                            )
                    elif (i-1) <= rank or not causal:
                        flash_attn_fwd_softmax_lse_correction(softmax_lse,
                                                              softmax_lse_per_step[i-1])
                    else:
                        # this step only covered the second half of Q
                        flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
                                                              softmax_lse_per_step[i-1])

                if i < cp_size:
                    flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)

        torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

        # Rescale every step's partial output to the merged LSE and accumulate it.
        softmax_lse = softmax_lse.to(torch.float)
        for i in range(cp_size):
            # [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn]
            out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
            if i <= rank or not causal:
                flash_attn_fwd_out_correction(out.view(*out_.shape),
                                              out_,
                                              softmax_lse,
                                              softmax_lse_per_step[i])
            else:
                flash_attn_fwd_out_correction(out[:, 1, ...],
                                              out_,
                                              softmax_lse_[..., 1, :],
                                              softmax_lse_per_step[i])

        # KV consumed at the last step; saved for backward
        kv = p2p_comm_buffers[-1]
        if use_fused_attention:
            out = out.view(out.shape[0], -1, *out.shape[-2:])
        else:
            out = out.view(-1, *out.shape[-2:])
        ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
        ctx.rng_states = rng_states
        ctx.cp_group = cp_group
        ctx.cp_global_ranks = cp_global_ranks
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
        return out

    @staticmethod
    def backward(ctx, dout):
        """Backward pass of context-parallel attention over the P2P ring.

        KV blocks are exchanged with point-to-point send/recv: each step sends to
        ``cp_global_ranks[rank-1]`` and receives from ``cp_global_ranks[rank+1]``.
        Each communication buffer has shape ``(2, *kv.shape)``. Slot 0 carries KV.
        Slot 1 carries the dK/dV partial sums being accumulated across steps.
        Steps are processed in reverse order of the forward pass.

        Returns one entry per forward input: ``dq`` for ``q``, ``dkv[0]`` / ``dkv[1]``
        for ``k`` / ``v``, and ``None`` for every other argument.
        """
        q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

        cp_size = get_distributed_world_size(ctx.cp_group)
        rank = get_distributed_rank(ctx.cp_group)
        send_dst = ctx.cp_global_ranks[(rank + cp_size - 1) % cp_size]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

        # Out-of-place unsqueeze: never mutate the tensors handed back by
        # ctx.saved_tensors (in-place metadata changes would leak into any later
        # re-run of this backward, e.g. with retain_graph=True).
        if ctx.causal:
            # [b, np, sq] -> [b, np, 2, sq//2]
            softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
            softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
            if ctx.use_fused_attention:
                # [b, np, sq//2] -> [b, np, sq//2, 1]
                softmax_lse_ = softmax_lse_.unsqueeze(-1)
        if ctx.use_fused_attention:
            # [b, np, sq] -> [b, np, sq, 1]
            softmax_lse = softmax_lse.unsqueeze(-1)

        out = out.view(*q.shape)
        dout = dout.view(*q.shape)
        # Flash Attn outputs
        dq = torch.empty_like(q)

        p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
                            torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
        p2p_comm_buffers[0][0].copy_(kv)
        send_recv_reqs = []

        fa_optional_backward_kwargs = {}
        if _flash_attn_2_4_plus:
            fa_optional_backward_kwargs["alibi_slopes"] = None
        if _flash_attn_2_4_1_plus:
            fa_optional_backward_kwargs["deterministic"] = ctx.deterministic

        for i in range(cp_size):
            # wait until KV is received
            for req in send_recv_reqs:
                req.wait()

            # First step only forwards KV (slot 0), last step only forwards dKV
            # (slot 1); intermediate steps forward both.
            send_tensor = p2p_comm_buffers[i%2]
            recv_tensor = p2p_comm_buffers[(i+1)%2]
            if i == 0:
                send_tensor = send_tensor[0]
                recv_tensor = recv_tensor[0]
            if i == (cp_size-1):
                send_tensor = send_tensor[1]
                recv_tensor = recv_tensor[1]

            send_recv_reqs = flash_attn_p2p_communicate(rank,
                                                        send_tensor,
                                                        send_dst,
                                                        recv_tensor,
                                                        recv_src,
                                                        ctx.cp_group,
                                                        batch_p2p_comm)

            kv = p2p_comm_buffers[i%2][0]
            # In reversed order of fwd
            if ctx.causal:
                if i == (cp_size-1):
                    if ctx.use_fused_attention:
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                        kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                        dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        dq_, dk_, dv_, _ = fused_attn_bwd(
                            ctx.max_seqlen_q, ctx.max_seqlen_k,
                            cu_seqlens_q, cu_seqlens_k,
                            q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
                            [softmax_lse, ctx.rng_states[cp_size-i-1]],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout="bshd_bshd_bshd",
                            attn_mask_type="causal",
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, 0]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                            ctx.max_seqlen_q, ctx.max_seqlen_k,
                            ctx.dropout_p, ctx.softmax_scale, True,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
                elif i >= (cp_size-rank-1):
                    if ctx.use_fused_attention:
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        q_ = q.view(q.shape[0], -1, *q.shape[-2:])
                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                        kv_ = kv[:, :, 0, ...].contiguous()
                        # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
                        out_ = out.view(out.shape[0], -1, *out.shape[-2:])
                        dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
                        dq_, dk_, dv_, _ = fused_attn_bwd(
                            ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                            cu_seqlens_q, cu_seqlens_k//2,
                            q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
                            [softmax_lse, ctx.rng_states[cp_size-i-1]],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout="bshd_bshd_bshd",
                            attn_mask_type="no_mask",
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        q_ = q.view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                        kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                        out_ = out.view(-1, *out.shape[-2:])
                        dout_ = dout.view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
                            ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                            ctx.dropout_p, ctx.softmax_scale, False,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
                else:
                    if ctx.use_fused_attention:
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        q_ = q[:, 1, ...].contiguous()
                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
                        kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
                        out_ = out[:, 1, ...].contiguous()
                        dout_ = dout[:, 1, ...].contiguous()
                        dq_, dk_, dv_, _ = fused_attn_bwd(
                            ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                            cu_seqlens_q//2, cu_seqlens_k,
                            q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
                            [softmax_lse_, ctx.rng_states[cp_size-i-1]],
                            tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                            attn_scale=ctx.softmax_scale,
                            dropout=ctx.dropout_p,
                            qkv_layout="bshd_bshd_bshd",
                            attn_mask_type="no_mask",
                        )
                    else:
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                        q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
                        dq_ = torch.empty_like(q_)
                        # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                        kv_ = kv.view(2, -1, *kv.shape[-2:])
                        dkv_ = torch.empty_like(kv_)
                        # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                        out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
                        dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
                        if _flash_attn_2_3_plus:
                            fa_optional_backward_kwargs["window_size"] = [-1, -1]
                        _flash_attn_backward(
                            dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
                            dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
                            ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                            ctx.dropout_p, ctx.softmax_scale, False,
                            rng_state=ctx.rng_states[cp_size-i-1],
                            **fa_optional_backward_kwargs
                        )
            else:
                if ctx.use_fused_attention:
                    dq_, dk_, dv_, _ = fused_attn_bwd(
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        cu_seqlens_q, cu_seqlens_k,
                        q, kv[0], kv[1], out, dout, TE_DType[q.dtype],
                        [softmax_lse, ctx.rng_states[cp_size-i-1]],
                        tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                        attn_scale=ctx.softmax_scale,
                        dropout=ctx.dropout_p,
                        qkv_layout="bshd_bshd_bshd",
                        attn_mask_type="no_mask",
                    )
                else:
                    # [b, sq, np, hn] -> [b*sq, np, hn]
                    q_ = q.view(-1, *q.shape[-2:])
                    dq_ = torch.empty_like(q_)
                    # [2, b, sk, np, hn] -> [2, b*sk, np, hn]
                    kv_ = kv.view(2, -1, *kv.shape[-2:])
                    dkv_ = torch.empty_like(kv_)
                    # [b, sq, np, hn] -> [b*sq, np, hn]
                    out_ = out.view(-1, *out.shape[-2:])
                    dout_ = dout.view(-1, *dout.shape[-2:])
                    if _flash_attn_2_3_plus:
                        fa_optional_backward_kwargs["window_size"] = [-1, -1]
                    # Pass this step's saved RNG state, as every other branch does;
                    # without it the backward kernel cannot replay the forward's
                    # dropout pattern when dropout_p > 0.
                    _flash_attn_backward(
                        dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
                        dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                        ctx.max_seqlen_q, ctx.max_seqlen_k,
                        ctx.dropout_p, ctx.softmax_scale, False,
                        rng_state=ctx.rng_states[cp_size-i-1],
                        **fa_optional_backward_kwargs
                    )

            if i >= (cp_size-rank-1) or not ctx.causal:
                # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
                # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
                dq_ = dq_.view(*dq.shape)
            else:
                # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
                dq_ = dq_.view(dq.shape[0], *dq.shape[2:])

            # Accumulate this step's dQ; the first write to each region is a copy.
            if ctx.causal:
                if i > (cp_size-rank-1):
                    dq.add_(dq_)
                elif i == (cp_size-rank-1):
                    if rank == (cp_size-1):
                        dq.copy_(dq_)
                    else:
                        dq[:, 0, ...].copy_(dq_[:, 0, ...])
                        dq[:, 1, ...].add_(dq_[:, 1, ...])
                elif i > 0:
                    dq[:, 1, ...].add_(dq_)
                else:
                    dq[:, 1, ...].copy_(dq_)
            else:
                if i == 0:
                    dq.copy_(dq_)
                else:
                    dq.add_(dq_)

            # wait until dKV is received
            for req in send_recv_reqs:
                req.wait()

            dkv = p2p_comm_buffers[(i+1)%2][1]
            if ctx.use_fused_attention:
                dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
            if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1):
                # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
                dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
            else:
                # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
                # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
                dkv_ = dkv_.view(*dkv.shape)

            # Fold this step's dK/dV into the partial sum received from the ring.
            if ctx.causal:
                if i == (cp_size-1):
                    if rank == 0:
                        dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
                        dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
                    else:
                        dkv.add_(dkv_)
                elif i >= (cp_size-rank-1):
                    if i == 0 and rank == (cp_size-1):
                        dkv[:, :, 0, ...].copy_(dkv_)
                    else:
                        dkv[:, :, 0, ...].add_(dkv_)
                elif i > 0:
                    dkv.add_(dkv_)
                else:
                    dkv.copy_(dkv_)
            else:
                if i == 0:
                    dkv.copy_(dkv_)
                else:
                    dkv.add_(dkv_)

        if ctx.causal:
            # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
            dq = dq.view(q.shape[0], -1, *q.shape[-2:])
            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
            dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
        return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
                None, None, None, None, None, None


def attn_forward_func_with_cp(
    is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
    cp_group, cp_global_ranks, cp_stream, softmax_scale=None, attn_mask_type="causal",
    deterministic=False, use_fused_attention=False
) -> torch.Tensor:
    """Attention implementation with context parallelism.

    Validates ``attn_mask_type`` and then forwards every argument, in order,
    to ``AttnFuncWithCP.apply``, returning its output unchanged.

    Raises
    ------
    AssertionError
        If ``attn_mask_type`` is neither ``"causal"`` nor ``"no_mask"``.
    """
    assert (attn_mask_type in ["causal", "no_mask"]
        ), f"Mask type of {attn_mask_type} is not supported with context parallelism!"
    out = AttnFuncWithCP.apply(
        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
        deterministic, use_fused_attention
    )
    return out


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """
    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        rotary_percent: float
            Percent of rotary dimension to use for rotary position embeddings.
            When below 1.0, `dim` is replaced by `int(dim * rotary_percent)`.
        seq_len_interpolation_factor: int
            factor by which discrete positions are interpolated via the trick in
            https://arxiv.org/abs/2306.15595. Only takes effect when
            `pretrained_max_position_embeddings` is also set.
        pretrained_max_position_embeddings: int
            pre-trained max_position_embeddings before position interpolation
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        # Inverse frequencies 10000^(-2k/dim), k = 0 .. dim/2 - 1, allocated on the
        # current CUDA device.
        inv_freq = 1.0 / (
            10000
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer('inv_freq', inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        offset: int, default = 0
            fixed offset added to every position before computing frequencies

        Returns
        -------
        torch.Tensor
            tensor of shape `[max_seq_len, 1, 1, 2 * inv_freq.numel()]` whose last
            dimension holds the per-position frequencies duplicated twice.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        # Position interpolation is applied only when both settings are provided.
        if (self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None):
            if (max_seq_len >
                self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        freqs = torch.einsum('i , j -> i j', seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        #  2 * dim in dimension size
        emb = torch.cat((freqs, freqs), dim=-1)
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))

class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Apply RoPE to `t` with the fused kernel selected by `tensor_format`.

        `bshd` input is transposed to `sbhd` for the kernel and the result is
        transposed back. `cu_seqlens` is only consumed by the `thd` kernel.

        Raises
        ------
        ValueError
            If `tensor_format` is not one of `sbhd`, `bshd`, `thd`.
        """
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(
                t.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Gradient w.r.t. `t` only.

        Returns exactly one entry per `forward` input
        (`t`, `freqs`, `tensor_format`, `cu_seqlens`).
        """
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        return grad_input, None, None, None


1093
1094
1095
1096
1097
1098
1099
1100
1101
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.

    Returns
    -------
    torch.Tensor
        Tensor of the same shape as `t`. Only the first `d2` channels of the last
        dimension are rotated; the rest are passed through unchanged.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert cur_seq_len <= max_seq_len, (
        f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    )
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)


cyanguwa's avatar
cyanguwa committed
1165
class _SplitAlongDim(torch.autograd.Function):
1166
1167
1168
1169
1170
    """"""

    @staticmethod
    def forward(ctx,
                mixed_x_layer: torch.Tensor,
cyanguwa's avatar
cyanguwa committed
1171
1172
                split_dim: int,
                split_size_or_sections: Union[int, List[int], Tuple[int]],
1173
    ) -> Tuple[torch.Tensor, ...]:
cyanguwa's avatar
cyanguwa committed
1174
1175
1176
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)
1177
1178
1179
1180
1181
1182

    @staticmethod
    def backward(ctx,
                 *grad_outputs):
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

cyanguwa's avatar
cyanguwa committed
1183
1184
1185
1186
1187
1188
1189
1190
1191
        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = ctx.split_size_or_sections
            assert (len(grad_outputs) == len(split_sizes)
                ), "Unequal number of gradients vs split sections for backprop!"
        if isinstance(ctx.split_size_or_sections, int):
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims

1192
1193
        noop_ok = True
        strides = grad_outputs[0].stride()
1194
        data_ptr = grad_outputs[0].untyped_storage().data_ptr()
cyanguwa's avatar
cyanguwa committed
1195
        shape = list(grad_outputs[0].shape)
1196
        for i, tensor in enumerate(grad_outputs):
cyanguwa's avatar
cyanguwa committed
1197
1198
1199
            shape_i = shape
            shape_i[split_dim] = split_sizes[i]
            offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:])
1200
            if (tensor.stride() != strides or
cyanguwa's avatar
cyanguwa committed
1201
                list(tensor.shape) != shape_i or
1202
                tensor.untyped_storage().data_ptr() != data_ptr or
cyanguwa's avatar
cyanguwa committed
1203
                tensor.storage_offset() != offset_size):
1204
1205
1206
1207
1208
1209
1210
                noop_ok = False
                break

        if noop_ok:
            ret = torch.Tensor().to(device=grad_outputs[0].device,
                                    dtype=grad_outputs[0].dtype)
            new_shape = list(shape)
cyanguwa's avatar
cyanguwa committed
1211
1212
            new_shape[split_dim] = sum(split_sizes)
            ret.set_(grad_outputs[0].untyped_storage(),
1213
1214
                     grad_outputs[0].storage_offset(),
                     new_shape,
cyanguwa's avatar
cyanguwa committed
1215
                     strides
1216
            )
cyanguwa's avatar
cyanguwa committed
1217
            return ret, None, None
1218

cyanguwa's avatar
cyanguwa committed
1219
        return torch.cat(grad_outputs, dim = split_dim), None, None
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2

    Parameters
    ----------
    norm_factor: float
        Divisor for the raw ``Q @ K^T`` scores. When QK layer scaling is active
        (``NVTE_APPLY_QK_LAYER_SCALING=1``, FP16 inputs and a ``layer_number``),
        the divisor is additionally multiplied by ``layer_number`` and the factor
        is handed back to the fused softmax.
    attention_dropout: float, default = 0.0
        Dropout probability applied to the attention probabilities.
    attention_dropout_ctx: Callable, default = `nullcontext`
        Zero-argument factory returning the context manager that wraps the
        dropout call.
    layer_number: int, default = `None`
        Layer index; QK layer scaling is disabled when this is ``None``.
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # An FP16 training trick required for certain GPT-like models.
        self.apply_qk_layer_scaling = (
            bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop

        ``query_layer``, ``key_layer`` and ``value_layer`` are 4D tensors in the
        format encoded by ``qkv_layout`` (``sbhd`` or ``bshd``; ``thd`` is rejected).
        ``bshd`` inputs are transposed and processed with the ``sbhd`` path.
        ``cu_seqlens_q`` / ``cu_seqlens_kv`` are accepted for interface parity but
        unused.

        Returns
        -------
        torch.Tensor
            ``[sq, b, np * hn]`` for ``sbhd`` or ``[b, sq, np * hn]`` for ``bshd``.

        Raises
        ------
        ValueError
            If ``core_attention_bias_type`` is not one of ``no_bias``,
            ``pre_scale_bias``, ``post_scale_bias`` or ``alibi``.
        """
        assert (qkv_layout in QKVLayouts
            ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
        assert (qkv_format != 'thd'
            ), """UnfusedDotProductAttention does not support variable sequence lengths!"""
        if qkv_format == 'bshd':
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [x.transpose(0, 1)
                for x in [query_layer, key_layer, value_layer]]

        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if key_layer.shape[2] != query_layer.shape[2]:
            # Fewer KV heads than query heads (GQA): repeat each KV head along the
            # head dimension so every query head has a matching K/V head.
            assert (query_layer.shape[2]%key_layer.shape[2]==0
                ),"The number of attention heads must be divisible by the number of GQA groups!"
            key_layer = key_layer.repeat_interleave(
                    query_layer.shape[2] // key_layer.shape[2], dim = 2)
            value_layer = value_layer.repeat_interleave(
                    query_layer.shape[2] // value_layer.shape[2], dim = 2)

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(
            output_size[2], output_size[0] * output_size[1], -1
        )
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
        is_bf16 = query_layer.dtype == torch.bfloat16
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        if is_in_onnx_export_mode() and is_bf16:
            matmul_result = matmul_result.bfloat16()

        scale = self.norm_factor
        if apply_qk_layer_scaling:
            scale *= self.layer_number

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / scale),
            )

        elif core_attention_bias_type == "pre_scale_bias":
            assert core_attention_bias is not None, "core_attention_bias should not be None!"
            # torch.Size takes a single iterable of dimensions.
            assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
                    ), "core_attention_bias must be in [1, h, sq, skv] shape!"
            matmul_result = torch.bmm(
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            )
            # Bias is added before scaling; cast back to the input dtype so the
            # scores match value_layer in the second BMM (same as post_scale_bias).
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3]).to(
                dtype=query_layer.dtype)
            matmul_result /= scale

        elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
            if core_attention_bias_type == "post_scale_bias":
                assert core_attention_bias is not None, "core_attention_bias should not be None!"
                assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
                        ), "core_attention_bias must be in [1, h, sq, skv] shape!"
            if core_attention_bias_type == "alibi":
                core_attention_bias = get_alibi(output_size[1], output_size[2], output_size[3])
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / scale),
            )
            matmul_result = (matmul_result.view(
                output_size[0], output_size[1], output_size[2], output_size[3])
                + core_attention_bias).view(-1, output_size[2], output_size[3]).to(
                dtype=query_layer.dtype)

        else:
            # Without this, matmul_result would be the uninitialized torch.empty buffer.
            raise ValueError(
                f"Unsupported core_attention_bias_type = {core_attention_bias_type}!")

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask, attn_mask_type, softmax_scale)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if qkv_format == 'sbhd':
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)

        if qkv_format == 'bshd':
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)

        return context_layer


class _PrepareQKVForFA(torch.autograd.Function):
    """This class converts QKV from interleaved (s, b, ...) layout
       to separate contiguous q, k, v tensors in (b, s, ...) layout."""

    @staticmethod
    def forward(ctx,
                query_layer: torch.Tensor,
                key_layer: torch.Tensor,
                value_layer: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split one interleaved QKV buffer into three tensors.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            ``(query_layer, key_layer, value_layer)``, produced by splitting the
            output of ``tex.fa_prepare_fwd`` into three chunks along dim 0 and
            squeezing that dim.
        """
        # All inputs received are non-contiguous tensors.
        # The `query_layer` tensor is used to access the
        # full memory region of the QKV tensor. `key_layer` and `value_layer`
        # are only inputs so that backward can return a gradient for each.
        qkv = tex.fa_prepare_fwd(query_layer)
        q, k, v = split_tensor_along_dim(qkv, 0, 3)
        query_layer = torch.squeeze(q, 0)
        key_layer = torch.squeeze(k, 0)
        value_layer = torch.squeeze(v, 0)
        return query_layer, key_layer, value_layer

    @staticmethod
    def backward(ctx,
                 dq: torch.Tensor,
                 dk: torch.Tensor,
                 dv: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Merge the three gradients with ``tex.fa_prepare_bwd`` and split the result
        into three chunks along the last dimension, one per forward input."""
        dqkv = tex.fa_prepare_bwd(dq, dk, dv)
        dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
        return dq, dk, dv

1447

1448
1449
1450
1451
1452
1453
1454
def _get_qkv_layout(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        qkv_format: str = 'sbhd',
    ) -> str:
    """Get qkv layout.
1455

1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
    Parameters
    ----------
    q: torch.Tensor
        Query tensor.
    k: torch.Tensor
        Key tensor.
    v: torch.Tensor
        Value tensor.
    qkv_format: str, default = `sbhd`
        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
        the sequence length dimension, `b` batch size, `h` the number of attention heads,
        `d` head size, and `t` the total number of sequences in a batch, i.e.
        `t = sum(s_i) for i = 0...b-1`.

    Returns
    ----------
    qkv_layout: str
       Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
       memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
       of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
       `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
       are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
       `v = kv[:,:,:,1,:]`.
       Mapping:
       `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
       `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
       `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
    """
1484

1485
1486
    check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
    assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
1487

1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
    def run_iteratively(q, k, v):
        data_ptr = q.untyped_storage().data_ptr()
        check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
        data_ptr = k.untyped_storage().data_ptr()
        check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

        stride = q.stride()
        check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
        stride = k.stride()
        check_strides_kv = all(stride == x.stride() for x in [k, v])

        shape = q.shape
        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
        shape = k.shape
        check_shapes_kv = all(shape == x.shape for x in [k, v])

        last_dim_size = q.shape[-1]
        check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset()
                            for i, x in enumerate([q, k, v]))
        last_dim_size = k.shape[-1]
        check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset()
                            for i, x in enumerate([k, v]))

        last_two_dims_size = q.shape[-1] * q.shape[-2]
        check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset()
                            for i, x in enumerate([q, k, v]))
        last_two_dims_size = k.shape[-1] * k.shape[-2]
        check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset()
                            for i, x in enumerate([k, v]))

        if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
            and check_last_two_dims_offsets_qkv
            and not check_last_dim_offsets_qkv):
            # sb3hd, bs3hd, t3hd
            qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:]
        elif (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv
            and check_last_dim_offsets_qkv):
            # sbh3d, bsh3d, th3d
            qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:]
        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
            and check_last_two_dims_offsets_kv
            and not check_last_dim_offsets_kv):
            # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
            qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:]
        elif (check_ptrs_kv and check_strides_kv and check_shapes_kv
            and check_last_dim_offsets_kv):
            # sbhd_sbh2d, bshd_bsh2d, thd_th2d
            qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:]
        elif check_strides_kv and check_shapes_kv:
            # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
            qkv_layout = '_'.join(list([qkv_format])*3)
        else:
            qkv_layout = 'not_supported'

        return qkv_layout

    qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == 'not_supported':
        # force q,k,v to be contiguous and run get_layout again
        q, k, v = [x.contiguous() for x in [q, k, v]]
        qkv_layout = run_iteratively(q, k, v)
    if qkv_layout == 'not_supported':
1550
1551
        raise Exception("The provided qkv memory layout is not supported!")

1552
    return qkv_layout, q, k, v
1553

1554

1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
def check_set_window_size(
        attn_mask_type: str,
        window_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[int, int]:
    """Check if sliding window size is compliant with mask type and if not,
    assert or set it to the appropriate size.

    Parameters
    ----------
    attn_mask_type: str
        Attention mask type; only whether it contains the substring ``"causal"``
        is inspected.
    window_size: Optional[Tuple[int, int]], default = `None`
        Sliding window ``(left, right)``. ``None`` selects the default window for
        the mask type. Any two-element sequence (e.g. a list) is accepted.

    Returns
    ----------
    window_size: Tuple[int, int]
        ``(-1, 0)`` for causal masks and ``(-1, -1)`` otherwise when ``window_size``
        is ``None``; otherwise the given window converted to a tuple.

    Raises
    ----------
    AssertionError
        If the mask type is causal and ``window_size[1] != 0``.
    """
    if window_size is None:
        return (-1, 0) if "causal" in attn_mask_type else (-1, -1)
    # Normalize to a tuple: callers compare against tuple literals such as
    # `window_size in ((-1, -1), (-1, 0))`, which never matches a list.
    window_size = tuple(window_size)
    if "causal" in attn_mask_type:
        assert (
            window_size[1] == 0
        ), "window_size[1] should be 0 when self_attn_mask_type includes 'causal'!"
    return window_size
1573

1574

1575
class FlashAttention(torch.nn.Module):
    """Dot product attention, using HazyResearch flash-attn package:
    https://github.com/Dao-AILab/flash-attention

    Parameters
    ----------
    norm_factor: float
        Softmax scale is ``1.0 / norm_factor``.
    attention_dropout: float, default = 0.0
        Dropout probability, applied only when ``self.training`` is True.
    attention_dropout_ctx: Callable, default = `nullcontext`
        Zero-argument factory returning the context manager wrapped around the
        attention call.
    attention_type: str, default = `self`
        ``"self"`` requires equal Q/KV max sequence lengths under padding masks;
        any other value packs Q and KV with separate indices.
    layer_number: int, default = `None`
        ``None`` is treated as 1. Only layer 1 computes ``cu_seqlens`` / packing
        indices; other layers reuse the values cached in module-level globals.
    deterministic: bool, default = `False`
        Forwarded to flash-attn (>= 2.4.1) and to the context-parallel path.
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."

        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
    ) -> torch.Tensor:
        """flash-attn fprop

        Inputs must be FP16/BF16 CUDA tensors in the format encoded by
        ``qkv_layout`` (``sbhd``, ``bshd`` or ``thd``).

        For ``sbhd``/``bshd`` the inputs are converted to ``[b, s, h, d]`` and, unless
        context parallelism is active, flattened to ``[b * s, h, d]``. Padding masks
        are handled by packing the valid tokens before the kernel and unpacking the
        output afterwards. ``thd`` inputs are already packed and require
        ``cu_seqlens_q`` / ``cu_seqlens_kv``.

        Returns
        -------
        torch.Tensor
            ``[sq, b, h * d]`` for ``sbhd``, ``[b, sq, h * d]`` for ``bshd`` and
            ``[t, h * d]`` for ``thd``.
        """

        window_size = check_set_window_size(attn_mask_type, window_size)

        assert (
            query_layer.dtype in [torch.float16, torch.bfloat16]
            and key_layer.dtype in [torch.float16, torch.bfloat16]
            and value_layer.dtype in [torch.float16, torch.bfloat16]
            ), "FlashAttention currently only supports FP16 and BF16."
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "FlashAttention currently only supports CUDA tensors."
        assert (
            qkv_layout in QKVLayouts
            ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])

        if qkv_format == 'sbhd':
            # For now just 128, will make it more general in the future
            if (query_layer.shape[-1] == 128 and
                query_layer.shape[0] * query_layer.shape[1] >= 512 and
                qkv_layout == "sbh3d"):
                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
                                                                             key_layer,
                                                                             value_layer)
            else:
                query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
                    for x in (query_layer, key_layer, value_layer)]
        elif qkv_format == 'bshd':
            query_layer, key_layer, value_layer = [x.contiguous()
                for x in (query_layer, key_layer, value_layer)]

        # Sequence metadata is computed once (layer_number == 1) and cached here so
        # that later layers can reuse it.
        global _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv
        batch_size = query_layer.shape[0]

        if qkv_format in ['sbhd', 'bshd']:
            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
            if not context_parallel:
                # [b * s, h, d]
                query_layer, key_layer, value_layer = [
                    x.view(x.shape[0] * x.shape[1], *x.shape[2:])
                    for x in [query_layer, key_layer, value_layer]
                ]

            if 'padding' in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"

                if self.attention_type == "self":
                    assert (
                        max_seqlen_q == max_seqlen_kv
                    ), "Maximum sequence length for Q and KV should be the same."
                    if self.layer_number == 1:
                        if cu_seqlens_q is None:
                            assert (attention_mask is not None
                                ), "Please provide attention_mask for padding!"
                            _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
                        else:
                            _cu_seqlens_q = cu_seqlens_q
                            _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                    _cu_seqlens_kv = _cu_seqlens_q
                    query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
                        _indices_q, query_layer, key_layer, value_layer
                    )
                else:
                    if self.layer_number == 1:
                        if cu_seqlens_q is None or cu_seqlens_kv is None:
                            assert (attention_mask is not None
                                ), "Please provide attention_mask for padding!"
                            _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(
                                attention_mask[0])
                            _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(
                                attention_mask[1])
                        else:
                            _cu_seqlens_q = cu_seqlens_q
                            _cu_seqlens_kv = cu_seqlens_kv
                            _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                            _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                    query_layer_packed = PackTensors.apply(_indices_q, query_layer)
                    key_layer_packed, value_layer_packed = PackTensors.apply(
                        _indices_kv, key_layer, value_layer
                    )
                query_layer, key_layer, value_layer = (
                    query_layer_packed, key_layer_packed, value_layer_packed)
                cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
            else:
                if self.layer_number == 1:
                    # No padding: every sequence has the full length, so the
                    # cumulative lengths are a uniform arange.
                    if cu_seqlens_q is None:
                        cu_seqlens_q = torch.arange(
                                0,
                                (batch_size + 1) * max_seqlen_q,
                                step=max_seqlen_q,
                                dtype=torch.int32,
                                device=query_layer.device)
                    if cu_seqlens_kv is None:
                        cu_seqlens_kv = torch.arange(
                                0,
                                (batch_size + 1) * max_seqlen_kv,
                                step=max_seqlen_kv,
                                dtype=torch.int32,
                                device=key_layer.device)
                    _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
                else:
                    cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
        elif qkv_format == 'thd':
            assert not context_parallel, "thd format not supported with context parallelism!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if context_parallel:
            # tuple() so that a list-valued window still matches the tuple literals.
            assert (
                tuple(window_size) in ((-1, -1), (-1, 0))
                ), "Sliding window attention is not supported with context parallelism."
            assert (
                alibi_slopes is None
            ), "Alibi slope bias addition is not supported with context parallelism."
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training, query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic
                )
        else:

            from .cpu_offload import CPUOffloadEnabled
            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                # Only pass keyword arguments supported by the installed flash-attn.
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
                    fa_optional_forward_kwargs["window_size"] = window_size
                if _flash_attn_2_4_plus:
                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
                if _flash_attn_2_4_1_plus:
                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                output = flash_attn_forward_func(
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    softmax_scale=1.0/self.norm_factor, causal="causal" in attn_mask_type,
                    **fa_optional_forward_kwargs,
                )

        # Only sbhd/bshd inputs were packed above; thd output is already in its
        # packed [t, h, d] layout and _indices_q was not computed for it.
        if qkv_format in ['sbhd', 'bshd'] and 'padding' in attn_mask_type:
            output = UnpackTensor.apply(_indices_q, batch_size * max_seqlen_q, output)

        if qkv_format == 'sbhd':
            # (bs)hd -> bs(hd) -> sb(hd)
            output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
        elif qkv_format == 'bshd':
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q, -1).contiguous()
        elif qkv_format == 'thd':
            # thd -> t(hd)
            output = output.view(output.shape[0], -1).contiguous()

        return output
1793
1794


1795
1796
1797
1798
1799
1800
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
                dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        """Run ``fused_attn_fwd_qkvpacked`` and stash everything backward needs."""
        # The five Nones fill optional tensor slots of fused_attn_fwd_qkvpacked that
        # this path does not use (presumably FP8 scaling tensors - TODO confirm).
        out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
            is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
            fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        ctx.save_for_backward(qkv, out, cu_seqlens)
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients for ``qkv`` and, for bias types other than
        ``no_bias``/``alibi``, for ``attn_bias``.

        Returns exactly one entry per forward input (15, excluding ``ctx``).
        """
        qkv, out, cu_seqlens = ctx.saved_tensors
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dqkv = torch.empty_like(qkv)

            def maybe_contiguous(x):
                return x.contiguous() if x.stride(-1) != 1 else x

            # q, k, v are stacked along dim 1 of the packed tensor.
            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, qkv[:,0], qkv[:,1], qkv[:,2], out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dqkv[:,0], dqkv[:,1], dqkv[:,2],
                cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            # Trim the last dim to d_out's size (presumably undoing head-dim
            # padding - TODO confirm).
            dqkv = dqkv[..., :d_out.shape[-1]]
        else:
            dqkv, *rest = fused_attn_bwd_qkvpacked(
                ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # if no_bias or alibi, only dqkv; else also dbias (first extra output).
        # NOTE: as before, `rest` is unbound on the FAv2 path with a learnable bias.
        if ctx.attn_bias_type in ["no_bias", "alibi"]:
            dbias = None
        else:
            dbias = rest[0]

        # Positions follow forward's inputs: qkv is index 3, attn_bias index 5.
        return (None, None, None, dqkv, None, dbias, None,
                None, None, None, None, None, None,
                None, None)

1861

1862
1863
1864
1865
1866
1867
1868
class FusedAttnFunc_kvpacked(torch.autograd.Function):
    """Function for FusedAttention with packed KV input.

    Forward calls ``fused_attn_fwd_kvpacked`` on a query tensor and a KV-packed
    tensor. Backward uses either cuDNN (``fused_attn_bwd_kvpacked``) or, when
    ``use_FAv2_bwd`` is set, the FlashAttention-2 varlen backward kernel.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        """Run fused attention forward and record what backward needs on ``ctx``.

        Returns:
            The attention output produced by ``fused_attn_fwd_kvpacked``.
        """
        out, aux_ctx_tensors = fused_attn_fwd_kvpacked(
            is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
            q, kv, qkv_dtype, fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv)
        # Auxiliary kernel outputs (e.g. softmax stats, RNG state) kept for backward.
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients w.r.t. ``q`` and ``kv`` (and the bias, if any).

        Returns:
            A gradient tuple aligned with the forward inputs: ``dq`` at position
            5, ``dkv`` at position 6 and, for bias types other than
            ``no_bias``/``alibi``, the bias gradient at position 8.
        """
        q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
        # The first auxiliary tensor (softmax_lse in the FA2 path) is made
        # contiguous before it is handed to either backward kernel.
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()

        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)

            def maybe_contiguous(x):
                # Copy only when the innermost dimension is not unit-stride.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, q, kv[:,0], kv[:,1], out)]
            # dk/dv are written in place into the two slices of dkv.
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dkv[:,0], dkv[:,1],
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            dq = dq[..., :d_out.shape[-1]]
            dkv = dkv[..., :d_out.shape[-1]]
        else:
            dq, dkv, *rest = fused_attn_bwd_kvpacked(
                ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, kv, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # 'no_bias' and 'alibi' produce no bias gradient; otherwise return
        # (dq, dkv, dbias) where dbias is the first extra cuDNN output.
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        return (None, None, None, None, None, dq, dkv, None, dbias, None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)


class FusedAttnFunc(torch.autograd.Function):
    """Function for FusedAttention with separate Q, K, V tensors.

    Forward calls ``fused_attn_fwd``. Backward uses either cuDNN
    (``fused_attn_bwd``) or, when ``use_FAv2_bwd`` is set, the
    FlashAttention-2 varlen backward kernel.
    """

    @staticmethod
    def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                qkv_layout, attn_bias_type, attn_mask_type,
                rng_gen, fused_attention_backend, use_FAv2_bwd):
        """Run fused attention forward and record what backward needs on ``ctx``.

        When CPU activation offloading is enabled, the saved tensors are tagged
        for offloading and the layout recorded for backward is overridden.

        Returns:
            The attention output produced by ``fused_attn_fwd``.
        """
        out, aux_ctx_tensors = fused_attn_fwd(
            is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
            q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
            None, None, None, None, None,
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv]
            # NOTE(review): the layout recorded for backward is forced to
            # 'sbhd_sbhd_sbhd' regardless of the input format, presumably because
            # offloaded tensors come back as separate buffers; confirm this is
            # also right for bshd inputs.
            qkv_layout = 'sbhd_sbhd_sbhd'
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv)
        # Auxiliary kernel outputs (e.g. softmax stats, RNG state) kept for backward.
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.fast_zero_fill = fast_zero_fill
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend
        ctx.use_FAv2_bwd = use_FAv2_bwd

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Compute gradients w.r.t. ``q``, ``k``, ``v`` (and the bias, if any).

        Returns:
            A gradient tuple aligned with the forward inputs: ``dq``, ``dk``,
            ``dv`` at positions 5-7 and, for bias types other than
            ``no_bias``/``alibi``, the bias gradient at position 9.
        """
        q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
        # The first auxiliary tensor (softmax_lse in the FA2 path) is made
        # contiguous before it is handed to either backward kernel.
        if not ctx.aux_ctx_tensors[0].is_contiguous():
            ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()

        if ctx.use_FAv2_bwd:
            softmax_lse, rng_state = ctx.aux_ctx_tensors
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)

            def maybe_contiguous(x):
                # Copy only when the innermost dimension is not unit-stride.
                return x.contiguous() if x.stride(-1) != 1 else x

            d_out, q, k, v, out = [maybe_contiguous(x)
                for x in (d_out, q, k, v, out)]
            flash_attn_cuda_bwd(
                d_out, q, k, v, out, softmax_lse, dq, dk, dv,
                cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv,
                ctx.dropout_p, ctx.attn_scale, False,
                "causal" in ctx.attn_mask_type, None, rng_state
            )
            dq = dq[..., :d_out.shape[-1]]
            dk = dk[..., :d_out.shape[-1]]
            dv = dv[..., :d_out.shape[-1]]
        else:
            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                q, k, v, out, d_out,
                ctx.qkv_dtype, ctx.aux_ctx_tensors,
                ctx.fused_attention_backend,
                None, None, None, None, None, None, None, None, None,
                ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
                ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # 'no_bias' and 'alibi' produce no bias gradient; otherwise return
        # (dq, dk, dv, dbias) where dbias is the first extra cuDNN output.
        dbias = None if ctx.attn_bias_type in ["no_bias", "alibi"] else rest[0]
        return (None, None, None, None, None, dq, dk, dv, None, dbias, None,
                None, None, None, None, None, None,
                None, None, None, None, None, None)


class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
       cuDNN based fused attention for FP16/BF16 and <=512 sequence length.
    2. FusedAttnBackend["F16_arbitrary_seqlen"]
       cuDNN based fused attention for FP16/BF16 and any sequence length.

    Support matrix:

    | backend       | 1                       | 2                              |
    | flash based   | no                      | yes                            |
    | cuDNN based   | yes                     | yes                            |
    | qkv dtype     | fp16/bf16               | fp16/bf16                      |
    | attn_type     | self/cross              | self/cross                     |
    | qkv_layout    |                         |                                |
    |  - (q,k,v)    | sb3hd, bs3hd            | sb3hd, bs3hd, sbh3d, bsh3d     |
    |               | sbhd_sb2hd, bshd_bs2hd  | sbhd_sb2hd, bshd_bs2hd         |
    |               | bshd_bshd_bshd          | sbhd_sbh2d, bshd_bsh2d         |
    |               |                         | sbhd_sbhd_sbhd, bshd_bshd_bshd |
    | mask_type     | causal/padding/no_mask  | causal/padding/no_mask         |
    | bias_type     | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias  |
    | dropout       | yes                     | yes                            |
    | max_seqlen    | <=512, multiple of 64   | any, multiple of 64            |
    | head_dim      | 64                      | <=128, multiple of 8           |
    | output dtype  | fp16/bf16               | fp16/bf16                      |

    Environment variables read at construction:

    - ``NVTE_FUSED_ATTN_USE_FAv2_BWD=1`` requests the FlashAttention-2 backward
      kernel; it is honored only on devices with compute capability 9.0.
    - ``NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`` ("0"/"1") overrides
      ``CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT`` (see ``__init__``).
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Store attention hyper-parameters and configure cuDNN workspace env vars.

        Args:
            norm_factor: softmax scale is ``1.0 / norm_factor``.
            attention_dropout: dropout probability applied in training mode.
            attention_dropout_ctx: context-manager factory entered around the kernel.
            attention_type: ``"self"`` or ``"cross"``; selects how padding masks
                are converted to cu_seqlens.
            layer_number: 1-based layer index; defaults to 1. Only layer 1 refreshes
                the module-level cu_seqlens cache used by ``forward``.
            deterministic: if True, sets ``CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=-1``.
        """
        super().__init__()

        self.norm_factor = norm_factor
        self.attention_dropout = attention_dropout
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_type = attention_type
        # FA2 backward is opt-in and restricted to compute capability 9.0.
        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
                             and get_device_compute_capability() == (9, 0))
        self.layer_number = 1 if layer_number is None else layer_number
        if deterministic:
            # workspace optimization path is deterministic
            os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
        # - unset:       enables workspace optimization when required workspace is <= 256MB
        #                or when bias gradient needs to be computed
        # - n:           enables workspace optimization when required workspace is <= n bytes
        # - -1:          enables workspace optimization always
        # - 0:           disables workspace optimization always
        if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
            if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

    @no_torch_dynamo()
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        fused_attention_backend:
            tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> torch.Tensor:
        """Fused attention fprop.

        For ``sbhd``/``bshd`` formats, batch size and max sequence lengths are
        read from the input shapes. Missing cu_seqlens are derived from the
        padding mask, or built as uniform offsets when there is no padding.
        The context-parallel path is used when ``cp_group`` spans more than
        one rank; otherwise ``FusedAttnFunc`` is called directly.

        Returns:
            The attention output with its last two dimensions (heads, head
            size) flattened into one.

        Raises:
            AssertionError: for unsupported backend, dtype, device, layout,
                format, or context-parallel combinations.
            ValueError: if a padding mask type is requested but neither
                ``attention_mask`` nor both cu_seqlens tensors are provided.
        """
        # Module-level cache shared across layers: layer 1 fills it and later
        # layers reuse it. NOTE(review): assumes layer 1 runs first each step.
        global _cu_seqlens_q, _cu_seqlens_kv

        assert (fused_attention_backend
            != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
            ), 'No fused attention backend supports this input combination!'
        assert (
            (query_layer.dtype in [torch.float16, torch.bfloat16])
            and (key_layer.dtype in [torch.float16, torch.bfloat16])
            and (value_layer.dtype in [torch.float16, torch.bfloat16])
            ), 'FusedAttention only supports FP16 and BF16 data types.'
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'FusedAttention only supports CUDA tensors.'
        assert (
            qkv_layout in QKVLayouts
            ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"

        context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)

        # e.g. 'sb3hd' -> 'sbhd', 'bshd_bs2hd' -> 'bshd'
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
        assert (
            qkv_format != 'thd'
            ), 'FusedAttention does not support qkv_format = thd!'

        if qkv_format in ['sbhd', 'bshd']:
            if qkv_format == 'sbhd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[1], query_layer.shape[0], key_layer.shape[0])
            if qkv_format == 'bshd':
                batch_size, max_seqlen_q, max_seqlen_kv = (
                    query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
            if 'padding' in attn_mask_type:
                assert not context_parallel, "Padding mask not supported with context parallelism!"
                if (cu_seqlens_q is not None and cu_seqlens_kv is not None):
                    # use cu_seqlens when both cu_seqlens and attention_mask are present
                    if self.layer_number == 1:
                        _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
                elif attention_mask is not None:
                    if self.attention_type == "self":
                        if self.layer_number == 1:
                            _cu_seqlens_q = get_cu_seqlens(attention_mask)
                            _cu_seqlens_kv = _cu_seqlens_q
                    else:
                        # Cross attention takes a (q_mask, kv_mask) pair.
                        if self.layer_number == 1:
                            _cu_seqlens_q = get_cu_seqlens(attention_mask[0])
                            _cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                else:
                    raise ValueError("Please provide attention_mask or cu_seqlens for padding!")
                cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
            else:
                if self.layer_number == 1:
                    # No padding: every sequence has the full max length.
                    if cu_seqlens_q is None:
                        cu_seqlens_q = torch.arange(
                                0,
                                (batch_size + 1) * max_seqlen_q,
                                step=max_seqlen_q,
                                dtype=torch.int32,
                                device=query_layer.device)
                    if cu_seqlens_kv is None:
                        cu_seqlens_kv = torch.arange(
                                0,
                                (batch_size + 1) * max_seqlen_kv,
                                step=max_seqlen_kv,
                                dtype=torch.int32,
                                device=key_layer.device)
                    _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
                else:
                    cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv

        qkv_dtype = TE_DType[query_layer.dtype]

        # FA2 backward is only used without bias and with the arbitrary-seqlen backend.
        use_FAv2_bwd = (self.use_FAv2_bwd
                and (core_attention_bias_type == "no_bias")
                and (fused_attention_backend
                    == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen))

        if context_parallel:
            assert (fused_attention_backend
                == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
                ), f"{fused_attention_backend} does not work with context parallelism!"
            assert (core_attention_bias_type == "no_bias"), \
                "Core attention bias has not been supported with context parallelism yet!"
            # The context-parallel path is fed batch-first tensors.
            if qkv_format == 'sbhd':
                query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
                    for x in (query_layer, key_layer, value_layer)]
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training,
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv,
                    max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    attn_mask_type=attn_mask_type,
                    use_fused_attention=True,
                )
            if qkv_format == 'sbhd':
                output = output.transpose(0,1).contiguous()
        else:
            with self.attention_dropout_ctx():
                output = FusedAttnFunc.apply(
                    self.training,
                    max_seqlen_q, max_seqlen_kv,
                    cu_seqlens_q, cu_seqlens_kv,
                    query_layer, key_layer, value_layer,
                    qkv_dtype,
                    core_attention_bias,
                    1.0/self.norm_factor,
                    self.attention_dropout if self.training else 0.0,
                    fast_zero_fill,
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    None, # rng_gen
                    fused_attention_backend,
                    use_FAv2_bwd,
                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)


class DotProductAttention(torch.nn.Module):
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.

    .. warning::

        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels : int
                number of key-value channels.
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: str, default = `causal`
                   type of attention mask passed into softmax operation, options are "`no_mask`",
                   "`padding`", "`causal`", "`padding,causal`", "`causal,padding`", and
                   "`arbitrary`", where "`padding,causal`" and "`causal,padding`" are equivalent.
                   This arg can be overridden by :attr:`attn_mask_type` in the `forward` method.
                   The constructor arg is useful for cases involving compilation/tracing, e.g.
                   ONNX export, while the forward arg is useful for dynamically changing mask
                   types, e.g. a different mask for training and inference. For "`no_mask`", no
                   attention mask is applied. For "`causal`" or the causal mask in
                   "`padding,causal`", TransformerEngine calculates and applies an upper
                   triangular mask to the softmax input; no user input is needed. For
                   "`padding`" or the padding mask in "`padding,causal`", users need to provide
                   the locations of padded tokens either via :attr:`cu_seqlens_q` and
                   :attr:`cu_seqlens_kv` in the shape of [batch_size + 1] or :attr:`attention_mask`
                   in the shape [batch_size, 1, 1, max_seq_len]. For the "`arbitrary`" mask, users
                   need to provide a mask that is broadcastable to the shape of softmax input.
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
    attention_type: str, default = `self`
                   type of attention, either "`self`" or "`cross`".
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.
    qkv_format: str, default = `sbhd`
               dimension format for `query_layer`, `key_layer` and `value_layer`,
               {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
               `h` the number of heads, `d` head size, and `t` the total number of tokens
               in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
               are used when sequences in a batch are of equal length or padded to
               equal length, and the `thd` format is used when sequences in a batch
               have different lengths. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
               For that, please use `_get_qkv_layout` to obtain the layout information.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    cp_group : ProcessGroup, default = `None`
              context parallel process group.
    cp_global_ranks : list of global rank IDs, default = `None`
                     global rank IDs of GPUs that are in cp_group.
    cp_stream : CUDA stream, default = `None`
               context parallelism splits flash attention into multiple steps for
               compute and communication overlapping. To address the wave quantization
               issue of each split step, we add an additional CUDA stream so that we
               can overlap two flash attention kernels.
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: Optional[List[int]] = None,
        cp_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Configure the attention backends; parameters are described on the class.

        Always builds an unfused backend. It also builds FlashAttention and
        FusedAttention backends when the device (compute capability >= 8.0)
        and the ``NVTE_FLASH_ATTN`` / ``NVTE_FUSED_ATTN`` env vars allow them.
        """
        super().__init__()

        self.qkv_format = qkv_format
        # Normalize "a,b" spellings to "a_b"; "causal_padding" is an alias of
        # "padding_causal".
        attn_mask_type = attn_mask_type.replace(",","_")
        if attn_mask_type == "causal_padding":
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        # Effective TP world size: taken from tp_group when one is provided.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker
        self.num_attention_heads = num_attention_heads
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

        self.hidden_size_per_attention_head = kv_channels
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        # Divide by the effective TP size, not the raw `tp_size` argument. The
        # argument defaults to 1 even when a tp_group is passed.
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)

        assert (num_attention_heads % self.num_gqa_groups == 0
                ), "The number of attention heads must be divisible by the number of GQA groups!"

        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            attention_dropout_ctx = get_rng_state_tracker().fork

        norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.device_compute_capability = get_device_compute_capability()
        # Deterministic if explicitly requested via env var or via PyTorch's global flag.
        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) \
                             or torch.are_deterministic_algorithms_enabled()

        self.use_flash_attention = (
            int(os.getenv("NVTE_FLASH_ATTN", "1"))
            and self.device_compute_capability >= (8, 0)
        )
        # Only warn when FlashAttention would otherwise have been used.
        if self.use_flash_attention and self.deterministic and not _flash_attn_2_4_1_plus:
            self.use_flash_attention = False
            warnings.warn(
                "Disabling usage of FlashAttention since version <2.4.1 does not support "
                "deterministic execution. In order to use FA with deterministic behavior,"
                " please install FlashAttention version >=2.4.1."
            )

        self.use_fused_attention = (
            int(os.getenv("NVTE_FUSED_ATTN", "1"))
            and self.device_compute_capability >= (8, 0)
        )

        assert (
            attention_type in AttnTypes
        ), f"attention_type {attention_type} not supported"

        self.attention_type = attention_type
        self.attention_dropout = attention_dropout

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }

        if self.use_flash_attention:
            self.flash_attention = FlashAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        if self.use_fused_attention:
            self.fused_attention = FusedAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)

        self.unfused_attention = UnfusedDotProductAttention(
            norm_factor, **attn_kwargs, layer_number=layer_number)

    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
        **forward_kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        """Run ``attention_func`` under activation checkpointing.

        The call goes through ``checkpoint`` along with this module's RNG-state
        tracker and tensor-parallel group. Positional and keyword arguments are
        forwarded to ``attention_func`` unchanged.

        Returns:
            Whatever ``checkpoint`` returns for the wrapped call.
        """

        def _run_attention(*args, **kwargs):
            # Thin wrapper so checkpoint always receives a local callable.
            return attention_func(*args, **kwargs)

        return checkpoint(
            _run_attention,
            False,
            self.get_rng_state_tracker,
            self.tp_group,
            *forward_args,
            **forward_kwargs,
        )

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Configure context parallelism on this module; call before running forward.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # Assigned in order: group, ranks, stream.
        self.cp_group, self.cp_global_ranks, self.cp_stream = (
            cp_group, cp_global_ranks, cp_stream
        )

    @no_torch_dynamo(recursive=False)
    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        checkpoint_core_attention: bool = False,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        .. note::

            For :attr:`qkv_format` = `"sbhd"`, :attr:`query_layer` must be of shape
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
            :attr:`kv_channels`), and :attr:`key_layer` / :attr:`value_layer` must share
            that layout with the per-partition number of GQA groups as the head
            dimension. Output of shape (:attr:`sequence_length`, :attr:`batch_size`,
            :attr:`num_attention_heads` * :attr:`kv_channels`) is returned.
            For `"bshd"` the first two input dimensions are swapped; for `"thd"` the
            inputs are 3D tensors whose sequences are delimited by
            :attr:`cu_seqlens_q` / :attr:`cu_seqlens_kv`.

        .. note::

            DotProductAttention supports three backends: 1) FlashAttention which calls
            HazyResearch/Dao-AILab's `flash-attn <https://github.com/Dao-AILab/flash-attention>`_
            PyTorch API, 2) FusedAttention which has multiple fused attention implementations
            based on `cuDNN Graph API
            <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
            (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
            UnfusedDotProductAttention which is the native PyTorch implementation
            with fused scaled masked softmax.

        .. note::

            Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
            and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
            and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
            FlashAttention over FusedAttention and over UnfusedDotProductAttention.
            If FusedAttention is being used, users can also choose to switch to flash-attn's
            implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
            (default: 0), because of the performance differences between various versions of
            flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
            can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
            optimizations in FusedAttention. When unset, TransformerEngine determines the code path
            based on its internal logic. These optimizations trade memory for performance
            and should be used with care.

        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
                   Must be one of {'sbhd', 'bshd', 'thd'}.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
                   Required when :attr:`qkv_format` is 'thd'.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
                   Required when :attr:`qkv_format` is 'thd'.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`. Only honoured for 'thd',
                      where it is calculated from `cu_seqlens_q` if not provided; for
                      'sbhd'/'bshd' it is always taken from the tensor shape.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`. Only honoured
                       for 'thd', where it is calculated from `cu_seqlens_kv` if not provided;
                       for 'sbhd'/'bshd' it is always taken from the tensor shape.
        attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
                       `arbitrary`}, default = `None`. Type of attention mask passed into
                       softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
                       If `None`, the mask type from initialization is used.
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention. If `None`, a value is derived
                    from :attr:`attn_mask_type` when one is given, and otherwise the window
                    size from initialization is used.
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     An fp32 bias of shape (nheads,) or (batch_size, nheads)
                     (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
                     is added to the attention score of query i and key j.
                     Only supported by the FlashAttention backend.
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
        fast_zero_fill: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.

        Returns
        -------
        torch.Tensor
            Output of whichever backend (FlashAttention, FusedAttention or
            UnfusedDotProductAttention) is selected by the filters below.

        Raises
        ------
        AssertionError
            On invalid inputs (non-CUDA tensors, shape/dtype/format mismatches, unsupported
            mask type, alibi slopes without FlashAttention, context parallelism on the
            unfused path).
        Exception
            If no backend supports the provided inputs.
        """

        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'DotProductAttention only supports CUDA tensors.'

        assert (key_layer.shape == value_layer.shape
            ), "Keys and values must have the same shape!"

        # ---- Resolve per-call overrides against the values from initialization ----

        # The window size is validated/derived against the raw user-provided mask string,
        # i.e. before the ',' -> '_' normalization below.
        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        else:
            # Accept comma-separated spellings and canonicalize the order so that
            # 'causal,padding' and 'padding,causal' both become 'padding_causal'.
            attn_mask_type = attn_mask_type.replace(",", "_")
            if attn_mask_type == "causal_padding":
                attn_mask_type = "padding_causal"

        assert (attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"

        if window_size is None:
            window_size = self.window_size

        if qkv_format is None:
            qkv_format = self.qkv_format

        # ---- Input validation ----

        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
            and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
        assert (qkv_format in ['sbhd', 'bshd', 'thd']
            ), "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

        if qkv_format == 'thd':
            assert (all(len(x.shape) == 3 for x in (query_layer, key_layer, value_layer))
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
            assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
                and len(cu_seqlens_q.shape) == 1
                and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_kv must both have shape [batch_size + 1]!"
            assert (cu_seqlens_q.dtype == torch.int32
                and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_kv must both be in dtype torch.int32!"
            # Packed format: the maximum per-sequence length is not visible in the tensor
            # shape, so derive it from the cumulative offsets when the caller omits it.
            if max_seqlen_q is None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                max_seqlen_q = seqlens_q.max().item()
            if max_seqlen_kv is None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                max_seqlen_kv = seqlens_kv.max().item()

        if qkv_format in ['sbhd', 'bshd']:
            assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
            # Padded formats: the sequence dimension of the tensors is authoritative.
            if qkv_format == 'sbhd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
            if qkv_format == 'bshd':
                max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
            # torch.all reduces on-device and syncs once, instead of the builtin all()
            # pulling each element back to the host individually.
            if cu_seqlens_q is not None:
                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                assert (bool(torch.all(seqlens_q <= max_seqlen_q))
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                    the sequence dimension in 'query_layer'!"""
            if cu_seqlens_kv is not None:
                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                assert (bool(torch.all(seqlens_kv <= max_seqlen_kv))
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                    the sequence dimension in 'key_layer' and 'value_layer'!"""

        qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
            query_layer, key_layer, value_layer, qkv_format = qkv_format)

        # ---- Backend selection ----

        # The priority for attention backends (subject to availability and clearing the filters)
        # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
        use_flash_attention = self.use_flash_attention
        use_fused_attention = self.use_fused_attention
        use_unfused_attention = True

        # The following section filters out some backends based on
        # certain asserts before executing the forward pass.

        # Filter: Input type.
        if (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16]
        ):
            use_flash_attention = False
            use_fused_attention = False

        # Filter: Device and dimensions.
        # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
        # FAv2 requires head_dim % 8 == 0
        if (key_layer.shape[-1] > 256
            or key_layer.shape[-1] % 8 != 0
            or (key_layer.shape[-1] > 192
                and self.device_compute_capability not in ((8, 0), (9, 0)))):
            use_flash_attention = False

        # Filter: cross attention + causal mask.
        if (_flash_attn_2_1_plus
            and "causal" in attn_mask_type
            and max_seqlen_q != max_seqlen_kv):
            warnings.warn(
                "Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
                "for causal mask in cross attention. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False

        # Filter: bias.
        if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
            use_flash_attention = False

        context_parallel = (self.cp_group is not None
                            and get_distributed_world_size(self.cp_group) != 1)

        # Filter: sliding window attention.
        # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
        if window_size not in ((-1, -1), (-1, 0)):
            use_fused_attention = False
            if (not _flash_attn_2_3_plus) or context_parallel:
                use_flash_attention = False

        # Filter: ONNX export.
        if is_in_onnx_export_mode():
            use_flash_attention = False
            use_fused_attention = False

        # Filter: Attention mask type.
        #   attn_mask_type(s)    |     supported backends
        # ------------------------------------------------
        #   no_mask              |     All
        #   padding              |     UnfusedDotProductAttention, FlashAttention, FusedAttention
        #   causal               |     All
        #   padding + causal     |     FlashAttention, FusedAttention
        #   arbitrary            |     UnfusedDotProductAttention
        #
        if attn_mask_type == "arbitrary":
            use_flash_attention = False
            use_fused_attention = False

        # Causal masks with unequal q/kv lengths are not handled by the unfused path.
        if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
            use_unfused_attention = False

        # `fused_attention_backend` is only bound on this branch; every later use is
        # guarded by a short-circuiting `use_fused_attention` check.
        if use_fused_attention:
            fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype],
                TE_DType[key_layer.dtype],
                QKVLayout[qkv_layout],
                AttnBiasType[core_attention_bias_type],
                AttnMaskType[attn_mask_type],
                self.attention_dropout,
                query_layer.shape[-2], # num_attn_heads
                key_layer.shape[-2], # num_gqa_groups
                max_seqlen_q,
                max_seqlen_kv,
                query_layer.shape[-1], # head_dim
            )
            # DPA does not support FP8; for FP8, use cpp_extensions modules directly
            is_backend_avail = (fused_attention_backend in
                [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
            # Context parallelism is only wired up for the arbitrary-seqlen fused backend.
            use_fused_attention = (
                use_fused_attention and is_backend_avail and
                (not context_parallel or
                 fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))

        # Filter: Alibi slopes
        if alibi_slopes is not None:
            use_fused_attention = False
            assert (
                use_flash_attention
            ), "Alibi slopes bias is only supported in the FlashAttention backend."

        # Filter: determinism.
        # backend                                  | deterministic
        # ---------------------------------------------------------
        # flash-attn v1                            | yes
        # flash-attn v2                            | no
        # FusedAttnBackend["F16_max512_seqlen"]    | yes
        # FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
        # UnfusedDotProductAttention               | yes
        #
        # Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
        # on sm90 architectures.
        #
        if (use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and self.deterministic
            and self.device_compute_capability != (9, 0)):
            use_fused_attention = False

        # Select FusedAttention on sm90 and FlashAttention on others for performance
        if (use_flash_attention
            and use_fused_attention
            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
            if self.device_compute_capability == (9, 0):
                use_flash_attention = False

        # ---- Dispatch ----

        if use_flash_attention:
            if _NVTE_DEBUG:
                print("[DotProductAttention]: using flash-attn",_flash_attn_version)
            return self.flash_attention(query_layer,
                                        key_layer,
                                        value_layer,
                                        attention_mask=attention_mask,
                                        qkv_layout=qkv_layout,
                                        cu_seqlens_q=cu_seqlens_q,
                                        cu_seqlens_kv=cu_seqlens_kv,
                                        attn_mask_type=attn_mask_type,
                                        window_size=window_size,
                                        alibi_slopes=alibi_slopes,
                                        cp_group=self.cp_group,
                                        cp_global_ranks=self.cp_global_ranks,
                                        cp_stream=self.cp_stream,
                                        max_seqlen_q=max_seqlen_q,
                                        max_seqlen_kv=max_seqlen_kv)

        if use_fused_attention:
            if _NVTE_DEBUG:
                print("[DotProductAttention]: using cuDNN fused attention (backend "
                    + str(int(fused_attention_backend)) + ")")
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.fused_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    fused_attention_backend=fused_attention_backend,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    fast_zero_fill=fast_zero_fill,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv)
            return self.fused_attention(
                query_layer,
                key_layer,
                value_layer,
                qkv_layout=qkv_layout,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                attn_mask_type=attn_mask_type,
                attention_mask=attention_mask,
                fused_attention_backend=fused_attention_backend,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                fast_zero_fill=fast_zero_fill,
                cp_group=self.cp_group,
                cp_global_ranks=self.cp_global_ranks,
                cp_stream=self.cp_stream,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv)

        assert (not context_parallel), \
            "Context parallelism is only implemented with Flash Attention and Fused Attention!"

        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            # Note the trailing space: the two literals are implicitly concatenated.
            warnings.warn(
                           "Attention activation Offloading is only implemented "
                           "with Flash Attention and Fused Attention!"
                         )

        if _NVTE_DEBUG:
            print("[DotProductAttention]: using unfused DPA")
        if use_unfused_attention:
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.unfused_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout = qkv_layout,
                    cu_seqlens_q = cu_seqlens_q,
                    cu_seqlens_kv = cu_seqlens_kv,
                    attn_mask_type = attn_mask_type,
                    attention_mask = attention_mask,
                    core_attention_bias_type = core_attention_bias_type,
                    core_attention_bias = core_attention_bias)
            return self.unfused_attention(query_layer,
                    key_layer,
                    value_layer,
                    qkv_layout = qkv_layout,
                    cu_seqlens_q = cu_seqlens_q,
                    cu_seqlens_kv = cu_seqlens_kv,
                    attn_mask_type = attn_mask_type,
                    attention_mask = attention_mask,
                    core_attention_bias_type = core_attention_bias_type,
                    core_attention_bias = core_attention_bias)

        raise Exception("No dot product attention support for the provided inputs!")
2873
2874


2875
2876
2877
2878
2879
2880
2881
class MultiheadAttention(torch.nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`attn_mask_type` includes `"padding"` or `"arbitrary"`.

2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    init_method : Callable, default = `None`
                 used for initializing the QKV weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing the PROJ weights in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
2910
2911
    attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                   default = `causal`
2912
2913
2914
2915
2916
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different
                   mask for training and inference. The init arg is useful for cases
                   involving compilation/tracing, e.g. ONNX export.
2917
2918
2919
2920
2921
2922
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
2936
2937
    input_layernorm: bool, default = `False`
                     if set to `True`, layer normalization to the input is applied.
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
    attention_type: { 'self', 'cross' }, default = 'self'
                   type of attention applied.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
2961
2962
2963
2964
2965
2966
2967
2968
    qkv_format: str, default = `sbhd`
            dimension format for `query_layer`, `key_layer` and `value_layer`,
            {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size,
            `h` the number of heads and `d` head size. `sbhd` and `bshd` formats
            are used for when sequences in a batch are of equal length or padded to
            equal length. Please note that these formats do not reflect how
            tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
            For that, please use `_get_qkv_layout` to gain the layout information.
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, the QKV layers are used as Column Parallel
                      whereas the PROJ layer is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, this module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
3009
3010
3011
3012
3013
3014
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: Optional[int] = None,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_bias: bool = False,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_split_rs: bool = False,
        ub_split_ag: bool = False,
        ub_atomic_gemm_rs: bool = False,
        ub_atomic_gemm_ag: bool = False,
        bias: bool = True,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
    ) -> None:
        """
        Build the QKV projection(s), the core attention module and the output
        projection. See the class docstring for a description of each argument.
        """
        super().__init__()

        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        # Normalize the window against the init-time mask type; `forward` may
        # override both the mask type and the window per call.
        self.window_size = check_set_window_size(attn_mask_type, window_size)
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_attention_heads = num_attention_heads
        self.return_bias = return_bias

        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Unfused Q/K/V weights are a plain concatenation along dim 0, so the
        # per-head interleaved interpretation only applies to the fused weight.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        if layer_number is not None:
            assert layer_number > 0, "layer_number must be a positive integer"

        # An explicit process group takes precedence over the `tp_size` argument.
        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.hidden_size_per_attention_head = kv_channels
        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)

        # Without an explicit GQA group count every query head gets its own
        # key/value head, i.e. plain multi-head attention.
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        assert (
            num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        # NOTE(review): derived from `hidden_size`, not `kv_channels`, while the
        # QKV reshape in `forward` uses `kv_channels` per head; this presumably
        # assumes kv_channels == hidden_size // num_attention_heads — TODO confirm.
        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads)

        # Arguments shared by every projection GEMM constructed below.
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": self.params_dtype,
            "device": device,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # A single projection emits Q (hidden_size columns) followed by K and
            # V (hidden_size_kv columns each, fewer than Q under GQA/MQA).
            parameters_split = None
            if not fuse_qkv_params:
                parameters_split = collections.OrderedDict([
                    ("query", hidden_size),
                    ("key", self.hidden_size_kv),
                    ("value", self.hidden_size_kv),
                ])
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=parameters_split,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_split_ag=ub_split_ag,
                    normalization=normalization,
                    ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=parameters_split,
                    **common_gemm_kwargs,
                )
        elif self.attention_type == "cross":
            # Q is projected from `hidden_states`; K and V share one projection
            # of the encoder output (see `self.key_value`).
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query",) if not fuse_qkv_params else None,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_split_ag=ub_split_ag,
                    normalization=normalization,
                    ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                    ub_name="qkv",
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            kv_channels,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=attention_dropout,
            qkv_format=self.qkv_format,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
        )

        # Output projection; row-parallel when `set_parallel_mode` is on, and the
        # only projection that honours the user's `return_bias`.
        self.proj = Linear(
            hidden_size,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=return_bias,
            parallel_mode="row" if set_parallel_mode else None,
            ub_split_rs=ub_split_rs,
            ub_split_ag=ub_split_ag,
            ub_atomic_gemm_rs=ub_atomic_gemm_rs,
            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
            ub_name="proj",
            **common_gemm_kwargs,
        )


    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Allocate one uninitialized key or value cache buffer for inference.

        Parameters
        ----------
        inference_max_sequence_len : int
                                    length of the sequence (first) dimension.
        batch_size : int
                    size of the batch (second) dimension.
        dtype : torch.dtype
               element type of the buffer.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``[inference_max_sequence_len, batch_size,
            num_gqa_groups_per_partition, hidden_size_per_attention_head]`` on the
            current CUDA device. It is sized by GQA groups rather than query heads
            because keys/values only have ``ng`` heads. Contents are uninitialized
            (``torch.empty``).
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_gqa_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is stored on this module and also forwarded to every
        submodule that defines ``set_tensor_parallel_group``. The projections
        built in ``__init__`` receive their own ``tp_group`` at construction
        time, so without forwarding they would keep the init-time value.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        # Deep iterate but skip self (yielded first by `modules()`) to avoid
        # infinite recursion; mirrors `set_context_parallel_group`.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        Nothing is stored on this module itself; the arguments are forwarded
        unchanged to every submodule that defines ``set_context_parallel_group``.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # Deep iterate but skip self (yielded first by `modules()`) to avoid
        # infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

    def forward(
        self,
        hidden_states: torch.Tensor,
3274
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
3275
        encoder_output: Optional[torch.Tensor] = None,
3276
        attn_mask_type: Optional[str] = None,
3277
        window_size: Optional[Tuple[int, int]] = None,
3278
3279
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
3280
        inference_params: Optional[InferenceParams] = None,
3281
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
3282
3283
3284
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
3285
    ) -> Tuple[Union[torch.Tensor, None], ...]:
3286
3287
3288
3289
3290
        """
        Forward propagation for MultiheadAttention layer.

        .. note::

3291
3292
            Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
3293
3294
3295
3296
3297

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
3298
3299
3300
3301
3302
3303
3304
3305
3306
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensor(s) used to mask out attention softmax input.
             It should be 'None' for 'causal' and 'no_mask' types. For 'padding' masks, it should be
             a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of
             two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]
             for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is
             broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
        attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                       default = `None`
3307
                       type of attention mask passed into softmax operation.
3308
3309
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
3310
3311
3312
3313
3314
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
3335
                    Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`, `alibi`}
3336
        core_attention_bias: Optional[torch.Tensor], default = `None`
3337
3338
                    Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
                    It should be 'None' for 'no_bias' and 'alibi' bias types.
3339
3340
3341
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
        """
3342
3343
        # hidden_states: [sq, b, h]

3344
3345
        if attn_mask_type is not None:
            window_size = check_set_window_size(attn_mask_type, window_size)
3346
        if attn_mask_type is None:
3347
            attn_mask_type = self.attn_mask_type
3348
3349
        if window_size is None:
            window_size = self.window_size
3350

3351
3352
3353
3354
3355
        if "padding" in attn_mask_type and attention_mask is not None:
            for i,_ in enumerate(attention_mask):
                assert (
                    attention_mask[i].dtype == torch.bool
                ), "Attention mask must be in boolean type!"
3356

3357
3358
        assert (core_attention_bias_type in AttnBiasTypes
                ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
3359

3360
3361
3362
3363
3364
3365
3366
3367
3368
        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================

        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
3369
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
3370
3371
                )
                inference_value_memory = self._allocate_memory(
3372
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
3373
3374
3375
3376
3377
3378
3379
3380
3381
3382
3383
3384
3385
3386
3387
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

cyanguwa's avatar
cyanguwa committed
3388
3389
        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
3390
3391
3392
3393
3394
3395
3396
3397
3398
3399
3400
3401
3402
3403
3404
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

cyanguwa's avatar
cyanguwa committed
3405
3406
            num_queries_per_key_value = (self.num_attention_heads_per_partition //
                                         self.num_gqa_groups_per_partition)
3407
            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
3408
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn]
3409
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
cyanguwa's avatar
cyanguwa committed
3410
3411
                    self.num_gqa_groups_per_partition,
                    (num_queries_per_key_value + 2),
3412
3413
3414
3415
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
cyanguwa's avatar
cyanguwa committed
3416
3417
3418
3419
3420
3421
3422
3423
3424
            else:
                # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head
                )
                # split along third last dimension
                split_dim = -3
3425
3426
3427

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
3428
3429
3430
3431
3432
3433
3434
3435
3436
            # qkv_weight_interleaved:
            #  [sq, b, ng, (np/ng + 2), hn]
            #  --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]
            # not qkv_weight_interleaved:
            #  [sq, b, (np/ng + 2), ng, hn]
            #  --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]
            if not is_in_onnx_export_mode():
                query_layer, key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1)
3437
                )
3438
            else:
cyanguwa's avatar
cyanguwa committed
3439
3440
3441
3442
3443
3444
3445
3446
3447
3448
3449
3450
                query_layer, key_layer, value_layer = torch.split(
                    mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim,
                 )

            # query: -> [sq, b, np, hn]
            # key, value: -> [sq, b, ng, hn]
            query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1,
                                                             self.hidden_size_per_attention_head)
                                                   for x in (query_layer, key_layer, value_layer))

        elif self.attention_type == "cross":
            # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
3451
            mixed_kv_layer = self.key_value(
cyanguwa's avatar
cyanguwa committed
3452
                encoder_output,
3453
3454
3455
3456
                is_first_microbatch=is_first_microbatch,
            )

            if self.qkv_weight_interleaved:
cyanguwa's avatar
cyanguwa committed
3457
                # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn]
3458
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
3459
                    self.num_gqa_groups_per_partition,
3460
3461
3462
3463
3464
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
cyanguwa's avatar
cyanguwa committed
3465
                # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn]
3466
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
3467
                    2 * self.num_gqa_groups_per_partition,
3468
3469
3470
3471
3472
3473
3474
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

cyanguwa's avatar
cyanguwa committed
3475
3476
3477
3478
3479
            # mixed_kv_layer --> 2 [sk, b, ng, hn]
            if not is_in_onnx_export_mode():
                key_layer, value_layer = _SplitAlongDim.apply(
                    mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2,
                )
3480
            else:
cyanguwa's avatar
cyanguwa committed
3481
3482
3483
                key_layer, value_layer = torch.split(
                    mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
                )
3484
3485
3486
3487
3488
3489
3490
3491
3492
3493
3494
3495
3496
3497
3498
3499
3500
3501
3502
3503
3504
3505
3506
3507
3508
3509
3510
3511

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

3512
3513
3514
3515
3516
        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

3517
3518
3519
3520
3521
3522
3523
3524
3525
3526
3527
3528
3529
3530
3531
3532
3533
3534
3535
        if inference_params and self.layer_number is not None:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...
            ] = key_layer
            inference_value_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...
            ] = value_layer
            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...
            ]

3536
3537
3538
            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
3539
                q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
3540
3541
3542
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

3543
3544
3545
3546
        # ==================================
        # core attention computation
        # ==================================

3547
3548
3549
        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
3550
3551
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format)
3552

3553
3554
3555
3556
        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
3557
            qkv_format=self.qkv_format,
3558
3559
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
3560
3561
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
3562
            window_size=window_size,
3563
3564
3565
3566
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            fast_zero_fill=fast_zero_fill,
3567
3568
3569
3570
3571
3572
        )

        # =================
        # Output. [sq, b, h]
        # =================

3573
        projection_output = self.proj(
3574
3575
3576
            context_layer, is_first_microbatch=is_first_microbatch
        )

3577
3578
3579
3580
3581
3582
3583
3584
        if self.return_bias:
            attention_output, attention_bias = projection_output
        else:
            attention_output, attention_bias = projection_output, None

        outputs = (attention_output,)
        if self.return_bias:
            outputs += (attention_bias,)
3585
        if self.input_layernorm and self.return_layernorm_output:
3586
3587
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]