attention.py 42.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""
5
from __future__ import annotations
6
7
from enum import Enum
from functools import partial
8
9
10
from typing import Optional, Tuple, Union
import warnings

11
from jax.ad_checkpoint import checkpoint_name
12
13
import jax
import jax.numpy as jnp
14
from flax.linen import make_attention_mask
15

16
17
18
19
20
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine_jax import nvte_get_qkv_format
21

22
from . import cpp_extensions as tex
23
24
25


class AttnBiasType(Enum):
    """
    Placement of the attention bias relative to the softmax scaling.

    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """

    # Each member wraps the matching NVTE_Bias_Type value from the C++ extension.
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """
    Attention mask types, each backed by an NVTE_Mask_Type value.

    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    CAUSAL_BOTTOM_RIGHT_MASK: A causal mask calculated from the bottom-right section.
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK: A combination of the bottom-right causal and padding masks.
    """

    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
    CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK

    def is_causal(self):
        """Return True for every mask type that applies a causal constraint."""
        causal_kinds = (
            AttnMaskType.CAUSAL_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        )
        return self in causal_kinds

    def is_padding(self):
        """Return True for every mask type that also accounts for padding."""
        padding_kinds = (
            AttnMaskType.PADDING_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        )
        return self in padding_kinds

    def is_bottom_right(self):
        """Return True when the causal mask is calculated from the bottom-right section."""
        bottom_right_kinds = (
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        )
        return self in bottom_right_kinds


class QKVFormat(Enum):
    """
    Memory format of the q, k, v tensors.

    SBHD: q,k,v memory layout with [s, b, ..., h, d]
    BSHD: q,k,v memory layout with [b, s, ..., h, d]
    THD: q,k,v memory layout is the same as BSHD, but multiple segments may be packed
         into a single sequence.
    """

    # Members mirror NVTE_QKV_Format from the C++ extension.
    SBHD = NVTE_QKV_Format.NVTE_SBHD
    BSHD = NVTE_QKV_Format.NVTE_BSHD
    THD = NVTE_QKV_Format.NVTE_THD

88

89
class QKVLayout(Enum):
90
91
92
93
94
95
96
97
98
99
    """
    BSHD Format:
        - BS3HD: q,k,v are interleave packed as a tensor with shape [b, s, 3, h, d].
        - BSHD_BS2HD: q with shape [b, s, h, d] and kv are interleaved with shape [b, s, 2, h, d].
        - BSHD_BSHD_BSHD: q,k,v are seperate tensors with shape [b, s, h, d]
    THD Format: Shape is same as BSHD layout but allow multiple segments packed in a sequence.
        - T3HD: q,k,v are interleave packed as a tensor with shape [b, s, 3, h, d].
        - THD_T2HD: q with shape [b, s, h, d] and kv are interleaved with shape [b, s, 2, h, d].
        - THD_THD_THD: q,k,v are seperate tensors with shape [b, s, h, d]
    """
100

101
102
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
103
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
104
105
106
107
    T3HD = NVTE_QKV_Layout.NVTE_T3HD
    THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
    THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD

108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    def get_qkv_format(self):
        """
        Return the corresponding qkv_format (BSHD, SBHD, THD)
        """
        return QKVFormat(nvte_get_qkv_format(self.value))

    def is_qkvpacked(self):
        """
        Return True if the query, key, value is packed
        """
        return self in [QKVLayout.BS3HD, QKVLayout.T3HD]

    def is_kvpacked(self):
        """
        Return True if the key, value is packed
        """
        return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD]

    def is_separate(self):
        """
        Return True if the query, key, value are three separate tensors
        """
        return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD]

    def is_thd(self):
        """
        Return True if the layout belongs to THD
        """
        return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
137

Reese Wang's avatar
Reese Wang committed
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
    def to_qkvpacked(self):
        """
        Return the corresponding qkvpacked format, useful when adjusting q, k, v layout
        """
        qkv_format = self.get_qkv_format()
        if qkv_format == QKVFormat.BSHD:
            return QKVLayout.BS3HD
        if qkv_format == QKVFormat.THD:
            return QKVLayout.T3HD
        raise ValueError(f"Unsupported {qkv_format=}")

    def to_kvpacked(self):
        """
        Return the corresponding kvpacked format, useful when adjusting q, k, v layout
        """
        qkv_format = self.get_qkv_format()
        if qkv_format == QKVFormat.BSHD:
            return QKVLayout.BSHD_BS2HD
        if qkv_format == QKVFormat.THD:
            return QKVLayout.THD_T2HD
        raise ValueError(f"Unsupported {qkv_format=}")

    def to_separate(self):
        """
        Return the corresponding separate format, useful when adjusting q, k, v layout
        """
        qkv_format = self.get_qkv_format()
        if qkv_format == QKVFormat.BSHD:
            return QKVLayout.BSHD_BSHD_BSHD
        if qkv_format == QKVFormat.THD:
            return QKVLayout.THD_THD_THD
        raise ValueError(f"Unsupported {qkv_format=}")

171

172
173
174
175
176
177
178
179
180
181
182
183
184
class CPStrategy(Enum):
    """Context parallel strategies for JAX fused attention.

    DEFAULT: Choose a strategy automatically if the context parallel axis is sharded.
    ALL_GATHER: All-gather / reduce-scatter implementation.
    RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
    """

    DEFAULT = 0  # automatic selection
    ALL_GATHER = 1  # all-gather / reduce-scatter
    RING = 2  # ring attention


Reese Wang's avatar
Reese Wang committed
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
class ReorderStrategy(Enum):
    """
    Token re-ordering strategies for context-parallel load balancing with a causal mask.

    - DualChunkSwap: Split each query into two chunks and mirror-swap the second chunks between
      GPUs. This is currently used for non-THD load balancing. It requires max_seqlens to be a
      multiple of 2 * cp_size.
      Example (4 GPUs, seqlens=16):
      - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15];
      - After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]

    - Striped: Distribute the tokens in a striped (interleaved) manner across the sequence.
      This is currently used for THD load balancing.
      Example (4 GPUs, seqlens=16):
      - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15]
      - After reorder: GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15]
    """

    DualChunkSwap = 0  # mirror swap of two chunks per rank (non-THD)
    Striped = 1  # interleaved distribution (THD)


207
def make_swa_mask(
    segment_pos_q: jnp.ndarray,
    segment_pos_kv: jnp.ndarray,
    window_size: Optional[Tuple[int, int]] = None,
    dtype: jax.typing.DTypeLike = jnp.float32,
):
    """
    Build a sliding-window attention mask (1 = attend, 0 = masked).

    Args:
        segment_pos_q (jnp.ndarray):
            Query positions within each segment. For example, a batch with segment_ids =
            [[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos =
            [[0, 1, 2, 0, 1, 2, 3, 4]].
        segment_pos_kv (jnp.ndarray):
            Key/value positions within each segment.
        window_size (Optional[Tuple[int, int]], optional):
            Sliding window size for local attention: the query at position i attends to keys
            in [i - window_size[0], i + window_size[1]] inclusive. A negative bound means that
            side of the window is unbounded; None means no sliding window at all.
            Defaults to None.
        dtype (jax.typing.DTypeLike, optional):
            Mask data type. Defaults to jnp.float32.

    Returns:
        jnp.ndarray:
            The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv].
    """
    if window_size is None:
        lower_reach, upper_reach = jnp.inf, jnp.inf
    else:
        lower_reach, upper_reach = window_size
        # A negative reach means "no limit" on that side.
        if lower_reach < 0:
            lower_reach = jnp.inf
        if upper_reach < 0:
            upper_reach = jnp.inf

    # Broadcast queries along the last axis and keys along the second-to-last axis.
    q_grid = segment_pos_q[..., None]
    kv_grid = segment_pos_kv[..., None, :]
    within_window = (kv_grid >= q_grid - lower_reach) & (kv_grid <= q_grid + upper_reach)

    # Insert the singleton head axis expected by attention masks.
    return within_window[..., None, :, :].astype(dtype)
246
247


248
249
250
251
252
253
def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert string attn_mask_type to AttnMaskType
    TE-JAX currently fall back to the padding version kernels for the libraries integration.
    The overhead between padding and non-padding version should be small.
    However, we will lease this limitation in the near feature.
    """
254
    match attn_mask_type:
255
        case "no_mask":
256
            return AttnMaskType.NO_MASK
257
        case "padding":
258
            return AttnMaskType.PADDING_MASK
259
        case "causal":
260
            return AttnMaskType.CAUSAL_MASK
261
262
        case "causal_bottom_right" | "bottom_right_causal":
            return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
263
        case "padding_causal" | "causal_padding":
264
            return AttnMaskType.PADDING_CAUSAL_MASK
265
266
267
268
269
270
271
        case (
            "padding_causal_bottom_right"
            | "causal_padding_bottom_right"
            | "bottom_right_causal_padding"
            | "bottom_right_padding_causal"
        ):
            return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
272
    raise ValueError(
273
274
275
        f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal',"
        " 'padding_causal', 'causal_padding', 'causal_bottom_right',"
        " 'padding_causal_bottom_right'}"
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
    )


def is_fused_attn_kernel_available(
    q_dtype,
    kv_dtype,
    qkv_layout,
    attn_bias_type,
    attn_mask_type,
    dropout_probability,
    q_num_heads,
    kv_num_heads,
    q_max_seqlen,
    kv_max_seqlen,
    head_dim,
    window_size: Optional[Tuple[int, int]] = None,
):
    """
    Check whether a fused attention kernel supports the given configuration.

    All arguments are forwarded to ``tex.FusedAttnHelper``. A ``window_size`` of None is passed
    on as (-1, -1).

    Returns:
        The result of ``FusedAttnHelper.is_fused_attn_kernel_available()``.
    """
    effective_window = (-1, -1) if window_size is None else window_size
    helper = tex.FusedAttnHelper(
        q_dtype,
        kv_dtype,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        dropout_probability,
        q_num_heads,
        kv_num_heads,
        q_max_seqlen,
        kv_max_seqlen,
        head_dim,
        effective_window,
    )
    return helper.is_fused_attn_kernel_available()
314
315


316
def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
    """Return (batch, q_max_seqlen, kv_max_seqlen) read from the q/kv tensor shapes.

    The expected number of tensors in ``qkv`` depends on the layout: 1 for qkv-packed,
    2 for kv-packed and 3 for separate tensors. Batch and sequence are read from dims 0 and 1.
    """
    if qkv_layout.is_qkvpacked():
        expected_count, expected_form = 1, "(qkvpacked,)"
    elif qkv_layout.is_kvpacked():
        expected_count, expected_form = 2, "(query, kvpacked)"
    elif qkv_layout.is_separate():
        expected_count, expected_form = 3, "(query, key, value)"
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")

    assert len(qkv) == expected_count, f"qkv must be {expected_form} with {qkv_layout=}"
    batch, q_max_seqlen, *_ = qkv[0].shape
    # A single packed tensor carries q, k and v, so the kv length equals the q length.
    kv_max_seqlen = q_max_seqlen if expected_count == 1 else qkv[1].shape[1]
    return batch, q_max_seqlen, kv_max_seqlen
332

333

Reese Wang's avatar
Reese Wang committed
334
def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
335
    """Reorders a tensor for load balancing the compute of causal attention."""
Reese Wang's avatar
Reese Wang committed
336
337
338
339
340
    if strategy == ReorderStrategy.DualChunkSwap:
        return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
    if strategy == ReorderStrategy.Striped:
        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
    raise ValueError(f"Unsupported {strategy=}")
341
342


Reese Wang's avatar
Reese Wang committed
343
344
345
def inverse_reorder_causal_load_balancing(
    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
):
346
    """Inverse operation of `reorder_causal_load_balancing`."""
Reese Wang's avatar
Reese Wang committed
347
348
349
350
351
    if strategy == ReorderStrategy.DualChunkSwap:
        return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
    if strategy == ReorderStrategy.Striped:
        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
    raise ValueError(f"Unsupported {strategy=}")
352
353


354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq):
    # bincount map with 0s
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    return seqlens, offsets


def _mask_to_seqlens_offset(mask, max_segments_per_seq):
    """Reduce a [b, 1, q, kv] mask holding segment ids into q/kv seqlens and offsets.

    The max over the kv axis gives each query row's id; the max over the q axis gives each
    key column's id.

    Returns:
        (q_seqlen, q_offset, kv_seqlen, kv_offset)
    """
    assert mask.shape[1] == 1
    ids_grid = mask.squeeze(axis=1)
    q_seqlen, q_offset = _get_seqlens_and_offsets(ids_grid.max(axis=-1), max_segments_per_seq)
    kv_seqlen, kv_offset = _get_seqlens_and_offsets(ids_grid.max(axis=-2), max_segments_per_seq)
    return q_seqlen, q_offset, kv_seqlen, kv_offset


381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def _fast_causal_adjust_seqlen_and_offsets(
    segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
):
    # The assumption is that for any segment tokens respect causal ordering except at the ends
    # of the segment. This allows us to tweak the length and offset by only looking at the start
    # and end tokens between segments.
    is_active_segment = jnp.logical_and(q_len > 0, kv_len > 0)

    q_seq_id_start = jnp.take(segment_pos_q, q_offset[..., :-1], fill_value=-1)
    kv_seq_id_start = jnp.take(segment_pos_kv, kv_offset[..., :-1], fill_value=-1)
    skip_start_token = jnp.logical_and(kv_seq_id_start > q_seq_id_start, is_active_segment).astype(
        jnp.int32
    )

    q_len -= skip_start_token
    q_offset += jnp.insert(skip_start_token, skip_start_token.shape[-1], 0, axis=-1)

    q_seq_id_end = jnp.take(segment_pos_q, q_offset[..., 1:] - 1, fill_value=-1)
    kv_seq_id_end = jnp.take(segment_pos_kv, kv_offset[..., 1:] - 1, fill_value=-1)
    skip_end_token = jnp.logical_and(kv_seq_id_end > q_seq_id_end, is_active_segment).astype(
        jnp.int32
    )

    kv_len -= skip_end_token

    return q_len, kv_len, q_offset, kv_offset


def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
    segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
):
    """Causal seqlens/offsets computed in O(Q + KV) without materializing the Q*KV mask.

    Returns:
        (q_len, kv_len, q_offset, kv_offset)
    """
    q_seqlens, q_offsets = _get_seqlens_and_offsets(segment_ids_q, max_segments_per_seq)
    kv_seqlens, kv_offsets = _get_seqlens_and_offsets(segment_ids_kv, max_segments_per_seq)
    adjusted = _fast_causal_adjust_seqlen_and_offsets(
        segment_pos_q, q_seqlens, q_offsets, segment_pos_kv, kv_seqlens, kv_offsets
    )
    return adjusted


419
420
421
422
423
424
425
426
427
def _segment_ids_pos_to_seqlens_offsets(
    segment_ids_q,
    segment_ids_kv,
    segment_pos_q,
    segment_pos_kv,
    attn_mask_type,
    window_size,
    max_segments_per_seq,
):
    """Derive per-segment sequence lengths and offsets from segment ids/positions (THD).

    Args:
        segment_ids_q, segment_ids_kv: Segment ids of the query / key-value tokens; 0 marks
            padding.
        segment_pos_q, segment_pos_kv: Position of each token inside its segment.
        attn_mask_type (AttnMaskType): Mask type; decides whether the causal constraint applies.
        window_size (Optional[Tuple[int, int]]): Sliding window; None or (-1, -1) means no window.
        max_segments_per_seq (int): Maximum number of segments in one packed sequence.

    Returns:
        (q_seqlen, kv_seqlen, q_offset, kv_offset)
    """
    # TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
    # Computing the full mask is expensive due to quadratic expansion of Q * KV masking.

    # Assumptions for cudnn causal mask correctness.
    # 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
    # 2. No intra-segment padding, only inter-segment padding allowed
    # 3. Only start or end token within a segment may violate the causal order relationship
    #        1 5 9     0 4 8 10    0 4 8
    #    0             x           x
    #    4   x         x x         x x
    #    8   x x       x x x       x x x
    #
    # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
    # examine only O(Q+KV) elements.
    #
    # The window check is grouped explicitly: written as `a and b or c`, Python would parse it
    # as `(a and b) or c` and send non-causal masks with window_size == (-1, -1) down this
    # causal-only path.
    no_sliding_window = window_size is None or window_size == (-1, -1)
    if attn_mask_type.is_causal() and no_sliding_window:
        return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
            segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
        )

    # General path: build the full mask (1 = attend, 0 = masked).
    segment_mask = make_attention_mask(
        segment_ids_q,
        segment_ids_kv,
        jnp.equal,
    )
    # Same as segment_mask, but attended entries carry their segment id instead of 1.
    segment_mask_with_id = make_attention_mask(
        segment_ids_q,
        segment_ids_kv,
        lambda x, y: jnp.equal(x, y) * x,
    )
    attn_mask = segment_mask
    if attn_mask_type.is_causal():
        causal_mask = make_attention_mask(
            segment_pos_q,
            segment_pos_kv,
            jnp.greater_equal,
        )
        attn_mask = jnp.logical_and(segment_mask, causal_mask)

    swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
    attn_mask = jnp.logical_and(attn_mask, swa_mask)

    attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
    q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
        attn_mask_with_id, max_segments_per_seq
    )
    return q_seqlen, kv_seqlen, q_offset, kv_offset


def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
    """Convert non-THD segment ids into per-batch q/kv sequence lengths.

    The mask does not support ragged offsets, so only lengths are produced.

    Args:
        segment_ids_q, segment_ids_kv: Segment ids with the sequence on the last axis; 0 marks
            padding.
        attn_mask_type (AttnMaskType): When it has no padding component, every token counts.

    Returns:
        (q_seq_lens, kv_seq_lens): int32 arrays with shape ``segment_ids.shape[:-1]``.
    """
    if not attn_mask_type.is_padding():
        # No padding: every row spans the full sequence. Build one length per batch row so the
        # shape matches the padding branch (previously a 0-d scalar was returned).
        q_max_seqlen = segment_ids_q.shape[-1]
        kv_max_seqlen = segment_ids_kv.shape[-1]
        q_seq_lens = jnp.full(segment_ids_q.shape[:-1], q_max_seqlen, dtype=jnp.int32)
        kv_seq_lens = jnp.full(segment_ids_kv.shape[:-1], kv_max_seqlen, dtype=jnp.int32)
    else:
        # Count non-padding tokens; summing the raw ids would over-count ids larger than 1.
        q_seq_lens = jnp.sum(segment_ids_q != 0, axis=-1).astype(jnp.int32)
        kv_seq_lens = jnp.sum(segment_ids_kv != 0, axis=-1).astype(jnp.int32)
    return q_seq_lens, kv_seq_lens


@jax.tree_util.register_pytree_node_class
class SequenceDescriptor:
    """A class to describe the sequences with flexible initialization.

    - SequenceDescriptor.from_seqlens
      For non-THD (non-packed) cases, where each batch has only 1 sequence.
    - SequenceDescriptor.from_seqlens_and_offsets
      For THD (packed) cases, where each batch may hold more than 1 sequence.
    - SequenceDescriptor.from_segment_ids_and_pos
      Experimental feature for THD (packed) cases with context parallelism.

    Every field is a (q, kv) pair of arrays; unused fields hold empty arrays. The instance is
    registered as a JAX pytree so it can be passed through jit/shard_map.
    """

    seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    seq_offsets: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    segment_ids: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    segment_pos: Optional[Tuple[jnp.ndarray, jnp.ndarray]]

    def __init__(self, seqlens=None, seq_offsets=None, segment_ids=None, segment_pos=None):
        """
        Store the (q, kv) pairs. Missing fields default to (jnp.zeros(0), jnp.zeros(0))
        because the primitive only accepts pure jax arrays.
        """
        self.seqlens = (jnp.zeros(0), jnp.zeros(0)) if seqlens is None else seqlens
        self.seq_offsets = (jnp.zeros(0), jnp.zeros(0)) if seq_offsets is None else seq_offsets
        self.segment_ids = (jnp.zeros(0), jnp.zeros(0)) if segment_ids is None else segment_ids
        self.segment_pos = (jnp.zeros(0), jnp.zeros(0)) if segment_pos is None else segment_pos

    def tree_flatten(self):
        """
        Flatten method to register as a pytree node. All four pairs are children; there is
        no auxiliary data.
        """
        return ((self.seqlens, self.seq_offsets, self.segment_ids, self.segment_pos), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Unflatten method to register as a pytree node.
        """
        del aux_data
        return cls(*children)

    def get_seqlens_and_offsets(
        self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq
    ):
        """
        Acquire the seqlens/offsets for the cuDNN backend.

        If no segment ids are stored, the stored seqlens/seq_offsets are returned unchanged.
        Otherwise they are derived from segment_ids/segment_pos: with per-segment offsets for
        THD layouts, and with lengths only (empty offsets) for other layouts.

        Returns:
            ((q_seqlens, kv_seqlens), (q_offsets, kv_offsets))
        """
        q_segment_ids, kv_segment_ids = self.segment_ids
        q_segment_pos, kv_segment_pos = self.segment_pos
        assert q_segment_ids.shape == q_segment_pos.shape
        assert kv_segment_ids.shape == kv_segment_pos.shape
        # No segment_ids/segment_pos
        if q_segment_ids.size + kv_segment_ids.size == 0:
            return self.seqlens, self.seq_offsets

        if qkv_layout.is_thd():
            q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets(
                q_segment_ids,
                kv_segment_ids,
                q_segment_pos,
                kv_segment_pos,
                attn_mask_type,
                window_size,
                max_segments_per_seq,
            )
        else:
            q_seqlens, kv_seqlens = _segment_ids_to_seqlens(
                q_segment_ids,
                kv_segment_ids,
                attn_mask_type,
            )
            q_offsets = kv_offsets = jnp.zeros(0)
        return (q_seqlens, kv_seqlens), (q_offsets, kv_offsets)

    @classmethod
    def _expand_to_pair(
        cls, value: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Internal helper to ensure a single value expands into a pair (q, kv).

        Raises:
            ValueError: if a tuple that does not have exactly 2 elements is given.
            TypeError: if ``value`` is neither a jax array nor a tuple.
        """
        if isinstance(value, tuple):
            if len(value) != 2:
                raise ValueError("Input tuple must have exactly 2 elements.")
            return value

        if isinstance(value, jnp.ndarray):
            return value, value  # Duplicate for q=kv case

        raise TypeError(
            "Expected a jax.numpy.ndarray or a tuple of two jax.numpy.ndarray, "
            f"but got {type(value).__name__}."
        )

    @classmethod
    def from_seqlens(
        cls,
        seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
    ) -> SequenceDescriptor:
        """
        Factory method for inputs with sequence lengths only (non-THD).
        Args:
            seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens):
                - q_seqlens (jnp.ndarray):
                  Sequence lengths for the query, with shape [batch].
                - kv_seqlens (jnp.ndarray):
                  Sequence lengths for the key and value, with shape [batch].
                A single array is used for both query and key/value.
        Return:
            A SequenceDescriptor with only seqlens initialized.
        """
        q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens)
        return cls(seqlens=(q_seqlens, kv_seqlens))

    @classmethod
    def from_seqlens_and_offsets(
        cls,
        seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
        seq_offsets: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
    ) -> SequenceDescriptor:
        """
        Factory method for inputs with sequence lengths and offsets (THD).
        Args:
            seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens):
                - q_seqlens (jnp.ndarray):
                  Per-segment sequence lengths for the query, with shape
                  [batch, max_segments_per_seq]. Unused positions are padded with -1.
                - kv_seqlens (jnp.ndarray):
                  Per-segment sequence lengths for the key and value, with shape
                  [batch, max_segments_per_seq]. Unused positions are padded with -1.
            seq_offsets(Tuple(jnp.ndarray, jnp.ndarray)) = (q_offsets, kv_offsets):
                - q_seq_offsets (jnp.ndarray):
                  The offsets in the sequence dim for the query, with shape
                  [batch, max_segments_per_seq + 1]. Unused positions are padded with -1.
                - kv_seq_offsets (jnp.ndarray):
                  The offsets in the sequence dim for the key and value, with shape
                  [batch, max_segments_per_seq + 1]. Unused positions are padded with -1.
            A single array for either argument is used for both query and key/value.
        Return:
            A SequenceDescriptor with seqlens/seq_offsets initialized.
        """
        q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens)
        q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets)
        return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets))

    @classmethod
    def from_segment_ids_and_pos(
        cls,
        segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
        segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
    ) -> SequenceDescriptor:
        """
        Experimental factory method for inputs with segment IDs and optional positions. (THD)
        Args:
            segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
                - q_segment_ids (jnp.ndarray):
                  Query segment ids start with 1, with shape [batch, max_seqlen].
                  0s are treated as paddings.
                - kv_segment_ids (jnp.ndarray):
                  Key, value segment ids start with 1, with shape [batch, max_seqlen].
                  0s are treated as paddings.
            segment_pos(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_pos, kv_segment_pos):
                - q_segment_pos (jnp.ndarray):
                  The position inside each segment for query, with shape [batch, max_seqlen].
                - kv_segment_pos (jnp.ndarray):
                  The position inside each segment for key, value, with shape
                  [batch, max_seqlen].
                When omitted, positions default to arange(max_seqlen) broadcast over the batch.
        Return:
            A SequenceDescriptor with segment_ids/segment_pos initialized.
        """
        q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)

        if segment_pos is not None:
            segment_pos = cls._expand_to_pair(segment_pos)
        else:

            def generate_default_pos(segment_ids):
                seqlen = segment_ids.shape[-1]
                return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape)

            q_seg_pos = generate_default_pos(q_seg_ids)
            kv_seg_pos = generate_default_pos(kv_seg_ids)
            segment_pos = (q_seg_pos, kv_seg_pos)

        return cls(
            segment_ids=(q_seg_ids, kv_seg_ids),
            segment_pos=segment_pos,
        )


def _legacy_fused_attn(
676
677
678
679
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    mask: Optional[jnp.ndarray],
    seed: Optional[jnp.ndarray],
680
681
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
682
    qkv_layout: QKVLayout,
683
684
685
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
686
    window_size: Optional[Tuple[int, int]] = None,
687
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
688
689
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
690
):
691
    """
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
    Perform non-THD (non-packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        mask (Optional[jnp.ndarray]):
            An optional mask tensor to mask out the attention scores, `True` means mask out.
            Intra-sequence padding is not valid. The padded tokens can only on the right-most.
            Otherwise the results will be wrong.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
Reese Wang's avatar
Reese Wang committed
710
711
712
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
713
714
715
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
716
        window_size (Optional[Tuple[int, int]]): Sliding window size.
717
718
719
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
720
721
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
722
    """
723
    assert (
724
        not qkv_layout.is_thd()
725
726
727
728
    ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."

    # Check inputs qkv
    match qkv_layout:
Reese Wang's avatar
Reese Wang committed
729
        case QKVLayout.BS3HD:
730
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
Reese Wang's avatar
Reese Wang committed
731
        case QKVLayout.BSHD_BS2HD:
732
733
734
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
Reese Wang's avatar
Reese Wang committed
735
        case QKVLayout.BSHD_BSHD_BSHD:
736
737
738
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
Reese Wang's avatar
Reese Wang committed
739
740
        case _:
            raise ValueError(f"Unknown {qkv_layout=}")
741
742

    # convert the mask to seqlens, mask doesn't support ragged offsets
743
    if not attn_mask_type.is_padding():
744
745
746
        batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
        q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
        kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
zlsh80826's avatar
zlsh80826 committed
747
    else:
748
        assert mask is not None
749
        mask = jnp.logical_not(mask)
750
        q_seq_lens = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]
751
        if attn_mask_type == AttnMaskType.PADDING_MASK:
752
            kv_seq_lens = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]
753
754
        else:
            # When mask is causal, the actual seqlen is not the last row, use max to find it
755
            kv_seq_lens = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
756

757
758
    output = _fused_attn(
        qkv,
759
        bias,
760
        SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
761
        seed,
762
763
764
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
765
766
767
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
768
        max_segments_per_seq=1,
769
        window_size=window_size,
770
        context_parallel_strategy=context_parallel_strategy,
771
772
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
773
    )
774

775
    return output
776
777


778
779
780
781
782
783
784
785
def fused_attn_thd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: jnp.ndarray,
    kv_seq_offsets: jnp.ndarray,
    seed: Optional[jnp.ndarray],
786
787
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
788
    qkv_layout: QKVLayout,
789
790
791
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
792
    max_segments_per_seq: int = 1,
793
    window_size: Optional[Tuple[int, int]] = None,
794
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
795
796
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
797
):
798
    """
799
    Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
800
    """
801
802
803
804
805
    warnings.warn(
        "fused_attn_thd is deprecated, please use fused_attn with SequenceDescriptor",
        DeprecationWarning,
    )

806
    assert (
807
        qkv_layout.is_thd()
808
809
810
811
    ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."

    # Check inputs qkv
    match qkv_layout:
Reese Wang's avatar
Reese Wang committed
812
        case QKVLayout.T3HD:
813
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
Reese Wang's avatar
Reese Wang committed
814
        case QKVLayout.THD_T2HD:
815
816
817
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
Reese Wang's avatar
Reese Wang committed
818
        case QKVLayout.THD_THD_THD:
819
820
821
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
Reese Wang's avatar
Reese Wang committed
822
823
        case _:
            raise ValueError(f"Unknown {qkv_layout=}")
824
825
826
827
828
829

    batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
    assert q_seq_lens.shape == (batch, q_max_seqlen)
    assert kv_seq_lens.shape == (batch, kv_max_seqlen)
    assert q_seq_offsets.shape == (batch, q_max_seqlen + 1)
    assert kv_seq_offsets.shape == (batch, kv_max_seqlen + 1)
830

831
    output = _fused_attn(
832
        qkv,
833
        bias,
834
835
836
        SequenceDescriptor.from_seqlens_and_offsets(
            (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
        ),
837
838
839
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
840
        qkv_layout=qkv_layout,
841
842
843
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
844
        max_segments_per_seq=max_segments_per_seq,
845
        window_size=window_size,
846
        context_parallel_strategy=context_parallel_strategy,
847
848
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
849
    )
850
851
852
853

    return output


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
def _fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    sequence_descriptor: SequenceDescriptor,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int,
    window_size: Optional[Tuple[int, int]],
    context_parallel_strategy: CPStrategy,
    context_parallel_causal_load_balanced: bool,
    context_parallel_axis: str,
):
    """
    Differentiable fused-attention core wrapped in ``jax.custom_vjp``.

    Positions 0-3 (``qkv``, ``bias``, ``sequence_descriptor``, ``seed``) are the primal
    inputs. Positions 4-14 (``attn_bias_type`` through ``context_parallel_axis``) are listed
    in ``nondiff_argnums``, so they are not differentiated and are passed as leading
    positional arguments to the backward rule.

    The primal evaluation reuses the forward rule and discards its residuals.
    """
    output, _ = _fused_attn_fwd_rule(
        qkv,
        bias,
        sequence_descriptor,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        window_size,
        context_parallel_strategy,
        context_parallel_causal_load_balanced,
        context_parallel_axis,
    )
    return output


def _fused_attn_fwd_rule(
    qkv,
    bias,
    sequence_descriptor,
    seed,
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_strategy,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
):
    """
    Forward rule of the ``_fused_attn`` custom VJP.

    Calls ``tex.fused_attn_fwd`` and returns ``(output, residuals)``. The residuals tuple is
    ``(qkv, bias, sequence_descriptor, softmax_aux, rng_state, output)``; it is unpacked in
    the same order by the backward rule.
    """
    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        qkv,
        bias,
        sequence_descriptor,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    # Tag the forward products with the "context" name so a jax.checkpoint (remat) policy
    # can choose to save them by name instead of recomputing.
    output = checkpoint_name(output, "context")
    softmax_aux = checkpoint_name(softmax_aux, "context")
    rng_state = checkpoint_name(rng_state, "context")
    return output, (
        qkv,
        bias,
        sequence_descriptor,
        softmax_aux,
        rng_state,
        output,
    )


def _fused_attn_bwd_rule(
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_strategy,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    ctx,
    dz,
):
    """
    Backward rule of the ``_fused_attn`` custom VJP.

    The non-differentiable configuration arguments come first, in ``nondiff_argnums``
    order. They are followed by ``ctx`` (the residuals from the forward rule) and ``dz``
    (the cotangent of the attention output).

    Returns:
        Cotangents for the primal inputs ``(qkv, bias, sequence_descriptor, seed)``.
        ``grad_bias`` is None when there is no bias. ``sequence_descriptor`` and ``seed``
        always receive None (no gradient).
    """
    (
        qkv,
        bias,
        sequence_descriptor,
        softmax_aux,
        rng_state,
        output,
    ) = ctx
    grad_qkv, grad_bias = tex.fused_attn_bwd(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        sequence_descriptor,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    # Without a bias there is nothing to differentiate; report a symbolic-zero cotangent.
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return (
        grad_qkv,
        grad_bias,
        None,
        None,
    )


# Register the forward/backward rules implementing the custom VJP of `_fused_attn`.
_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)


def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    sequence_descriptor: SequenceDescriptor,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    Perform cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence.
            Deprecated: passing a mask array (or None) here instead selects the legacy
            non-THD mask path. That path emits a DeprecationWarning and requires
            ``max_segments_per_seq == 1``.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]):
            Sliding window size.
        context_parallel_strategy (CPStrategy): Context parallelism strategy.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
    Raises:
        ValueError: If a mask is passed (legacy path) together with
            ``max_segments_per_seq != 1``.

    Examples (non-THD, also known as non-packed):
        >>> #  q_segment_ids = [[1, 1, 1, 0], [1, 1, 0, 0]], 0 means padded tokens
        >>> # kv_segment_ids = [[1, 0, 0, 0], [1, 1, 0, 0]], 0 means padded tokens
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> q_seq_lens = jnp.asarray([3, 2])
        >>> kv_seq_lens = jnp.asarray([1, 2])
        >>> sequence_desc = SequenceDescriptor.from_seqlens(
                seqlens=(q_seq_lens, kv_seq_lens))
        >>> out = fused_attn((qkv,), None, sequence_desc, None,
                             AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
                             QKVLayout.BS3HD, 0.125, 0, True)

    Examples (THD, also known as packed):
        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
        >>> # segment_pos = [[0, 1, 0, 0], [0, 1, 0, 1]]
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> # 3 segments in first seq, 2 segments in second seq
        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
        >>> # seq_offsets need to include the end offset of the last segments
        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
        >>> sequence_desc = SequenceDescriptor.from_seqlens_and_offsets(
                seqlens=(q_seq_lens, kv_seq_lens),
                seq_offsets=(q_seq_offsets, kv_seq_offsets))
        >>> out = fused_attn((qkv,), None, sequence_desc, None,
                             AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
                             QKVLayout.T3HD, 0.125, 0, True, 3)
    """
    # Legacy path: callers that still pass a mask array (or None) in place of a descriptor.
    if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray):
        warnings.warn(
            "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. "
            + "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.",
            DeprecationWarning,
            # Point the warning at the caller that still passes a mask.
            stacklevel=2,
        )
        if max_segments_per_seq != 1:
            raise ValueError("Passing mask is only supported for non-THD case.")
        return _legacy_fused_attn(
            qkv,
            bias,
            sequence_descriptor,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training,
            window_size=window_size,
            context_parallel_strategy=context_parallel_strategy,
            context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
            context_parallel_axis=context_parallel_axis,
        )

    output = _fused_attn(
        qkv,
        bias,
        sequence_descriptor,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output