attention.py 53.1 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""
5
from __future__ import annotations
6
7
from enum import Enum
from functools import partial
8
9
10
from typing import Optional, Tuple, Union
import warnings

11
from jax.ad_checkpoint import checkpoint_name
12
13
import jax
import jax.numpy as jnp
14
from flax.linen import make_attention_mask
15

16
17
18
19
20
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine_jax import nvte_get_qkv_format
21
from transformer_engine_jax import NVTE_Softmax_Type
22

23
from . import cpp_extensions as tex
24
25
26


class AttnBiasType(Enum):
    """Supported ways of applying an additive bias to the attention logits.

    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """

    # Values mirror the NVTE_Bias_Type enum from the transformer_engine_jax
    # extension, so members can be handed straight to the fused-attention backend.
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """Attention mask variants understood by the fused-attention backend.

    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    """

    # Values mirror the NVTE_Mask_Type enum from the transformer_engine_jax extension.
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
    CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK

    def is_causal(self):
        """Returns True if the mask is a causal mask"""
        causal_variants = (
            AttnMaskType.CAUSAL_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        )
        return self in causal_variants

    def is_padding(self):
        """Returns True if the mask includes padding"""
        padded_variants = (
            AttnMaskType.PADDING_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        )
        return self in padded_variants

    def is_bottom_right(self):
        """Returns True if the causal mask is calculated from the bottom-right section"""
        bottom_right_variants = (
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        )
        return self in bottom_right_variants


78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
class AttnSoftmaxType(Enum):
    """Softmax variants used inside fused attention.

    VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
    OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
    LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
    where alpha is a learnable parameter in shape [H].
    """

    # Values mirror the NVTE_Softmax_Type enum from the transformer_engine_jax extension.
    VANILLA_SOFTMAX = NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX
    OFF_BY_ONE_SOFTMAX = NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
    LEARNABLE_SOFTMAX = NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX

    @classmethod
    def from_str(cls, softmax_type: str) -> "AttnSoftmaxType":
        """Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
        by_name = {
            "vanilla": cls.VANILLA_SOFTMAX,
            "off_by_one": cls.OFF_BY_ONE_SOFTMAX,
            "learnable": cls.LEARNABLE_SOFTMAX,
        }
        if softmax_type not in by_name:
            raise ValueError(
                f"Unknown softmax_type: {softmax_type}. "
                "Valid options: 'vanilla', 'off_by_one', 'learnable'"
            )
        return by_name[softmax_type]


107
108
109
110
111
112
113
114
115
116
117
class QKVFormat(Enum):
    """High-level memory format of the q, k, v tensors.

    SBHD: q,k,v memory layout with [s, b, ..., h, d]
    BSHD: q,k,v memory layout with [b, s, ..., h, d]
    THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence.
    """

    # Values mirror the NVTE_QKV_Format enum from the transformer_engine_jax extension.
    SBHD = NVTE_QKV_Format.NVTE_SBHD
    BSHD = NVTE_QKV_Format.NVTE_BSHD
    THD = NVTE_QKV_Format.NVTE_THD

118

119
class QKVLayout(Enum):
    """Concrete q, k, v tensor layouts (how the three tensors are packed in memory).

    BSHD Format:
        - BS3HD: q,k,v are interleave packed as a tensor with shape [b, s, 3, h, d].
        - BSHD_BS2HD: q with shape [b, s, h, d] and kv are interleaved with shape [b, s, 2, h, d].
        - BSHD_BSHD_BSHD: q,k,v are separate tensors with shape [b, s, h, d]
    THD Format: Shape is same as BSHD layout but allow multiple segments packed in a sequence.
        - T3HD: q,k,v are interleave packed as a tensor with shape [b, s, 3, h, d].
        - THD_T2HD: q with shape [b, s, h, d] and kv are interleaved with shape [b, s, 2, h, d].
        - THD_THD_THD: q,k,v are separate tensors with shape [b, s, h, d]
    """

    # Values mirror the NVTE_QKV_Layout enum from the transformer_engine_jax extension.
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
    T3HD = NVTE_QKV_Layout.NVTE_T3HD
    THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
    THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD

    def get_qkv_format(self):
        """Return the corresponding qkv_format (BSHD, SBHD, THD)."""
        return QKVFormat(nvte_get_qkv_format(self.value))

    def is_qkvpacked(self):
        """Return True if the query, key, value are packed into one tensor."""
        return self in (QKVLayout.BS3HD, QKVLayout.T3HD)

    def is_kvpacked(self):
        """Return True if the key, value are packed into one tensor."""
        return self in (QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD)

    def is_separate(self):
        """Return True if the query, key, value are three separate tensors."""
        return self in (QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD)

    def is_thd(self):
        """Return True if the layout belongs to the THD family."""
        return self in (QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD)

    def to_qkvpacked(self):
        """Return the qkvpacked layout with the same format; useful when adjusting q, k, v layout."""
        qkv_format = self.get_qkv_format()
        packed_by_format = {QKVFormat.BSHD: QKVLayout.BS3HD, QKVFormat.THD: QKVLayout.T3HD}
        if qkv_format not in packed_by_format:
            raise ValueError(f"Unsupported {qkv_format=}")
        return packed_by_format[qkv_format]

    def to_kvpacked(self):
        """Return the kvpacked layout with the same format; useful when adjusting q, k, v layout."""
        qkv_format = self.get_qkv_format()
        packed_by_format = {QKVFormat.BSHD: QKVLayout.BSHD_BS2HD, QKVFormat.THD: QKVLayout.THD_T2HD}
        if qkv_format not in packed_by_format:
            raise ValueError(f"Unsupported {qkv_format=}")
        return packed_by_format[qkv_format]

    def to_separate(self):
        """Return the separate-tensor layout with the same format; useful when adjusting q, k, v layout."""
        qkv_format = self.get_qkv_format()
        separate_by_format = {
            QKVFormat.BSHD: QKVLayout.BSHD_BSHD_BSHD,
            QKVFormat.THD: QKVLayout.THD_THD_THD,
        }
        if qkv_format not in separate_by_format:
            raise ValueError(f"Unsupported {qkv_format=}")
        return separate_by_format[qkv_format]

201

202
203
204
205
206
207
208
209
210
211
212
213
214
class CPStrategy(Enum):
    """Defines the context parallel strategies of Jax fused attention.

    DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
    ALL_GATHER: All-gather/reduce scatter implementation.
    RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
    """

    # Plain ints (no NVTE counterpart): this choice is resolved on the JAX side only.
    DEFAULT = 0
    ALL_GATHER = 1
    RING = 2


Reese Wang's avatar
Reese Wang committed
215
216
217
218
219
220
class ReorderStrategy(Enum):
    """
    Defines the tokens re-order strategy for context parallel load balancing for causal mask.

    - DualChunkSwap: This strategy splits each query into two chunks and does a mirror swap
      between GPUs. This is currently used for non-THD load balance. It requires the
      max_seqlens to be a multiple of 2 * cp_size.
      Example: Consider 4 GPUs with seqlens=16.
      - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15];
      - After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]

    - Striped: This strategy distributes the tokens in a striped (interleaved) manner across
      the sequence. This is currently used for THD load balance.
      Example: Consider 4 GPUs with seqlens=16.
      - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15]
      - After reorder: GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15]
    """

    # Plain ints: strategy selection happens on the JAX side.
    DualChunkSwap = 0
    Striped = 1


237
def make_swa_mask(
    segment_pos_q: jnp.ndarray,
    segment_pos_kv: jnp.ndarray,
    window_size: Optional[Tuple[int, int]] = None,
    dtype: jax.typing.DTypeLike = jnp.float32,
    segment_ids_q: jnp.ndarray = None,
    segment_ids_kv: jnp.ndarray = None,
):
    """
    Generate a sliding window mask (1 = attend, 0 = masked).

    Args:
        segment_pos_q (jnp.ndarray):
            Query positions within each segment. For example, a batch with segment_ids =
            [[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos = [[0, 1, 2, 0, 1, 2, 3, 4]].
        segment_pos_kv (jnp.ndarray):
            Key/value positions within each segment.
        window_size (Optional[Tuple[int, int]], optional):
            Sliding window size for local attention, where query at position i attends to keys
            in [i - window_size[0], i + window_size[1]] inclusive. A negative number means an
            infinite window; None means no sliding window. Defaults to None.
        dtype (jax.typing.DTypeLike, optional):
            Mask data type. Defaults to jnp.float32.
        segment_ids_q (jnp.ndarray):
            Query segment id that each token belongs to.
        segment_ids_kv (jnp.ndarray):
            Key/value segment id that each token belongs to.

    Returns:
        jnp.ndarray:
            The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv].
    """
    if window_size is None:
        left_window = right_window = jnp.inf
    else:
        left_window, right_window = window_size
        # A negative bound means the window is unbounded on that side.
        if left_window < 0:
            left_window = jnp.inf
        if right_window < 0:
            right_window = jnp.inf

    # Broadcast q positions along the kv axis and vice versa.
    pos_q = segment_pos_q[..., :, None]
    pos_kv = segment_pos_kv[..., None, :]

    # Bottom Right Causal Mask (BRCM) path: anchor windows at the segment ends using
    # per-token run lengths (segment sizes).
    if segment_ids_q is not None and segment_ids_kv is not None:
        seg_len_q = run_length_fill(segment_ids_q)[..., :, None]
        seg_len_kv = run_length_fill(segment_ids_kv)[..., None, :]
        # NOTE(review): only the left window bound participates here; right_window is
        # unused on this path — presumably implied by bottom-right-causal semantics, confirm.
        brcm_keep = seg_len_q - pos_q + left_window >= seg_len_kv - pos_kv
        return brcm_keep[..., None, :, :].astype(dtype)

    # Standard sliding-window band: keep keys within [pos_q - left, pos_q + right].
    keep = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window)
    return keep[..., None, :, :].astype(dtype)
294
295


296
297
298
299
300
301
def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert a string attn_mask_type (including aliases) into an AttnMaskType member.

    TE-JAX currently falls back to the padding version kernels for the libraries integration;
    the overhead between padding and non-padding versions should be small. This limitation
    may be lifted in the future.
    """
    brcm_padding = AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
    alias_map = {
        "no_mask": AttnMaskType.NO_MASK,
        "padding": AttnMaskType.PADDING_MASK,
        "causal": AttnMaskType.CAUSAL_MASK,
        "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        "bottom_right_causal": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK,
        "causal_padding": AttnMaskType.PADDING_CAUSAL_MASK,
        "padding_causal_bottom_right": brcm_padding,
        "causal_padding_bottom_right": brcm_padding,
        "bottom_right_causal_padding": brcm_padding,
        "bottom_right_padding_causal": brcm_padding,
    }
    if attn_mask_type in alias_map:
        return alias_map[attn_mask_type]
    raise ValueError(
        f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal',"
        " 'padding_causal', 'causal_padding', 'causal_bottom_right',"
        " 'padding_causal_bottom_right'}"
    )


def is_fused_attn_kernel_available(
    is_training,
    q_dtype,
    kv_dtype,
    qkv_layout,
    attn_bias_type,
    attn_mask_type,
    softmax_type,
    dropout_probability,
    q_num_heads,
    kv_num_heads,
    q_max_seqlen,
    kv_max_seqlen,
    head_dim_qk,
    head_dim_v,
    window_size: Optional[Tuple[int, int]] = None,
):
    """
    To check whether the fused attention kernel is supported for this configuration.

    Delegates the capability query to tex.FusedAttnHelper; a window_size of None is
    normalized to (-1, -1) (infinite window on both sides) for the backend.
    """
    helper = tex.FusedAttnHelper(
        is_training,
        q_dtype,
        kv_dtype,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_probability,
        q_num_heads,
        kv_num_heads,
        q_max_seqlen,
        kv_max_seqlen,
        head_dim_qk,
        head_dim_v,
        (-1, -1) if window_size is None else window_size,
    )
    return helper.is_fused_attn_kernel_available()
369
370


371
def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
372
373
374
375
376
377
378
379
380
381
382
383
384
385
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}"
        batch, q_max_seqlen, *_ = qkv[0].shape
        kv_max_seqlen = q_max_seqlen
    elif qkv_layout.is_kvpacked():
        assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}"
        batch, q_max_seqlen, *_ = qkv[0].shape
        kv_max_seqlen = qkv[1].shape[1]
    elif qkv_layout.is_separate():
        assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}"
        batch, q_max_seqlen, *_ = qkv[0].shape
        kv_max_seqlen = qkv[1].shape[1]
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")
386
    return batch, q_max_seqlen, kv_max_seqlen
387

388

389
390
391
def reorder_causal_load_balancing(
    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int | None = None
):
    """Reorders a tensor for load balancing the compute of causal attention."""
    if strategy == ReorderStrategy.DualChunkSwap:
        # Dual-chunk reordering uses a fixed chunking scheme; stripe_size does not apply.
        if stripe_size is not None:
            raise ValueError(
                f"Incorrect value for CP dual chunk reordering {stripe_size=}. stripe_size must be"
                " None"
            )
        return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)

    if strategy == ReorderStrategy.Striped:
        # stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
        # stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
        if stripe_size is not None and stripe_size <= 0:
            raise ValueError(
                f"Incorrect value for CP striped reordering {stripe_size=}. stripe_size must be a"
                " positive integer"
            )
        # None keeps the legacy default of single-token stripes.
        stripes = stripe_size if stripe_size is not None else 1
        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False, stripes)

    raise ValueError(f"Unsupported {strategy=}")
414
415


Reese Wang's avatar
Reese Wang committed
416
def inverse_reorder_causal_load_balancing(
    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int | None = None
):
    """Inverse operation of `reorder_causal_load_balancing`."""
    if strategy == ReorderStrategy.DualChunkSwap:
        # Dual-chunk reordering uses a fixed chunking scheme; stripe_size does not apply.
        if stripe_size is not None:
            raise ValueError(
                f"Incorrect value for CP dual chunk reordering {stripe_size=}. stripe_size must be"
                " None"
            )
        return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)

    if strategy == ReorderStrategy.Striped:
        # stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
        # stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
        if stripe_size is not None and stripe_size <= 0:
            raise ValueError(
                f"Incorrect value for CP reordering {stripe_size=}. stripe_size must be a positive"
                " integer"
            )
        # None keeps the legacy default of single-token stripes.
        stripes = stripe_size if stripe_size is not None else 1
        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True, stripes)

    raise ValueError(f"Unsupported {strategy=}")
441
442


443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq):
    # bincount map with 0s
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    return seqlens, offsets


def _mask_to_seqlens_offset(mask, max_segments_per_seq):
    """Collapse a [b, 1, sq, skv] id-carrying mask into per-axis seqlens and offsets."""
    assert mask.shape[1] == 1
    squeezed = mask.squeeze(axis=1)
    # Reducing across the kv axis (resp. q axis) recovers the segment id carried by
    # each q token (resp. kv token).
    row_ids = squeezed.max(axis=-1)
    col_ids = squeezed.max(axis=-2)
    q_seqlen, q_offset = _get_seqlens_and_offsets(row_ids, max_segments_per_seq)
    kv_seqlen, kv_offset = _get_seqlens_and_offsets(col_ids, max_segments_per_seq)
    return q_seqlen, q_offset, kv_seqlen, kv_offset


470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
def _fast_causal_adjust_seqlen_and_offsets(
    segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
):
    """Adjust per-segment lengths/offsets at segment boundaries for the causal fast path.

    Returns the adjusted (q_len, kv_len, q_offset, kv_offset).
    """
    # The assumption is that for any segment tokens respect causal ordering except at the ends
    # of the segment. This allows us to tweak the length and offset by only looking at the start
    # and end tokens between segments.
    # Segments with zero q or zero kv tokens never need adjustment.
    is_active_segment = jnp.logical_and(q_len > 0, kv_len > 0)

    # Position-in-segment of the first q/kv token of every segment
    # (fill_value=-1 marks slots past the last real segment).
    q_seq_id_start = jnp.take(segment_pos_q, q_offset[..., :-1], fill_value=-1)
    kv_seq_id_start = jnp.take(segment_pos_kv, kv_offset[..., :-1], fill_value=-1)
    # If the kv side starts at a later in-segment position than the q side, drop the
    # leading q token of that segment (it would precede all visible kv tokens).
    skip_start_token = jnp.logical_and(kv_seq_id_start > q_seq_id_start, is_active_segment).astype(
        jnp.int32
    )

    q_len -= skip_start_token
    # Advance each segment's q offset by the dropped token; the trailing column is padded
    # with 0 so the offset array keeps its length.
    q_offset += jnp.insert(skip_start_token, skip_start_token.shape[-1], 0, axis=-1)

    # Position-in-segment of the last q/kv token of every segment.
    q_seq_id_end = jnp.take(segment_pos_q, q_offset[..., 1:] - 1, fill_value=-1)
    kv_seq_id_end = jnp.take(segment_pos_kv, kv_offset[..., 1:] - 1, fill_value=-1)
    # If the kv side extends past the last q position, trim the trailing kv token —
    # presumably unreachable under causal ordering (NOTE(review): confirm).
    skip_end_token = jnp.logical_and(kv_seq_id_end > q_seq_id_end, is_active_segment).astype(
        jnp.int32
    )

    kv_len -= skip_end_token

    return q_len, kv_len, q_offset, kv_offset


def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
    segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
):
    """O(Q+KV) causal-path conversion of segment ids/positions to seqlens and offsets."""
    # Compute raw lengths/offsets for both sides, then fix up the segment boundaries.
    (q_len, q_offset), (kv_len, kv_offset) = (
        _get_seqlens_and_offsets(ids, max_segments_per_seq)
        for ids in (segment_ids_q, segment_ids_kv)
    )
    return _fast_causal_adjust_seqlen_and_offsets(
        segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
    )


508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
def run_length_fill_flattened(segment_ids_flattened) -> jnp.ndarray:
    """Replace every token of a 1-D segment-id array with the length of its run.

    Padding tokens (id 0) map to 0.
    Example:
        input:  [1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0]
        output: [2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0]
    """
    ids = segment_ids_flattened
    # True at position 0 and wherever the id changes -> start of a new run.
    run_starts = jnp.concatenate([jnp.ones((1,), dtype=bool), ids[1:] != ids[:-1]])
    run_index = jnp.cumsum(run_starts) - 1
    # Worst case, every token starts its own run, so size the histogram accordingly.
    run_lengths = jnp.bincount(run_index, length=ids.shape[-1])
    # Broadcast each run's length back onto its tokens, then zero out padding positions.
    filled = run_lengths[run_index]
    return jnp.where(ids == 0, 0, filled)


def run_length_fill(segment_ids) -> jnp.ndarray:
    """Run-length fill of segment ids for arbitrary batch shape (last axis is the sequence).

    Example:
        input:  [[1 1 2 2 2 0 3 0], [1 0 0 2 2 2 0 0]]
        output: [[2 2 3 3 3 0 1 0], [1 0 0 3 3 3 0 0]]
    """
    # Collapse all leading dims so the 1-D kernel can be vmapped over rows.
    shape = segment_ids.shape
    rows = segment_ids.reshape(-1, shape[-1])
    filled_rows = jax.vmap(run_length_fill_flattened)(rows)
    return filled_rows.reshape(shape)


544
545
546
547
548
549
550
551
552
def _segment_ids_pos_to_seqlens_offsets(
    segment_ids_q,
    segment_ids_kv,
    segment_pos_q,
    segment_pos_kv,
    attn_mask_type,
    window_size,
    max_segments_per_seq,
):
    """Convert THD segment ids/positions into cuDNN-style seqlens and ragged offsets.

    Returns (q_seqlen, kv_seqlen, q_offset, kv_offset). A fast O(Q+KV) path handles the
    plain causal case; other mask types materialize the full [b, 1, sq, skv] mask.
    """
    # TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
    # Computing the full mask is expensive due to quadratic expansion of Q * KV masking.

    # Assumptions for cudnn causal mask correctness.
    # 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
    # 2. No intra-segment padding, only inter-segment paddding allowed
    # 3. Only start or end token within a segment may violate the causal order relationship
    #        1 5 9     0 4 8 10    0 4 8
    #    0             x           x
    #    4   x         x x         x x
    #    8   x x       x x x       x x x
    #
    # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
    # examine only O(Q+KV) elements.

    # For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
    # using the segment ids and pos along with mask type (causal or brcm) is sufficient.
    # It does not need to involve SW for this mask's creation

    # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
    if (attn_mask_type.is_causal() and window_size is None) or (
        window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
    ):
        return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
            segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
        )

    # Slow path: materialize the [b, 1, sq, skv] mask.
    # (1 = attend, 0 = masked)
    segment_mask = make_attention_mask(
        segment_ids_q,
        segment_ids_kv,
        jnp.equal,
    )
    # Same mask, but each allowed entry carries its q segment id instead of 1,
    # which lets _mask_to_seqlens_offset recover per-segment extents later.
    segment_mask_with_id = make_attention_mask(
        segment_ids_q,
        segment_ids_kv,
        lambda x, y: jnp.equal(x, y) * x,
    )
    # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied
    attn_mask = segment_mask
    if attn_mask_type.is_bottom_right():
        # Per-token segment sizes are needed to anchor the diagonal at the bottom-right.
        run_length_out_q = run_length_fill(segment_ids_q)
        run_length_out_kv = run_length_fill(segment_ids_kv)
        # Example for brcm:
        # run_length_out_q:  [3 3 3 0 4 4 4 4]
        # segment_pos_q:     [0 1 2 3 0 1 2 3]
        # segment_ids_q:     [1 1 1 0 2 2 2 2]
        # run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10]
        # segment_pos_kv:    [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9]
        # segment_ids_kv:    [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2]
        # brcm:            [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
        #                    [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
        #                    [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
        #                    [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
        #                    [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
        #                    [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
        #                    [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
        #                    [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]]
        # attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
        #                    [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
        #                    [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
        #                    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
        #                    [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
        #                    [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
        #                    [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
        #                    [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]]
        bottom_right_causal_mask = make_attention_mask(
            run_length_out_q - segment_pos_q,
            run_length_out_kv - segment_pos_kv,
            jnp.less_equal,
        )
        attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask)
    elif attn_mask_type.is_causal():
        # Top-left causal: q position must be >= kv position within the same segment.
        causal_mask = make_attention_mask(
            segment_pos_q,
            segment_pos_kv,
            jnp.greater_equal,
        )
        attn_mask = jnp.logical_and(segment_mask, causal_mask)

    # Attach the segment ids to the surviving entries, then reduce to lengths/offsets.
    attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
    q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
        attn_mask_with_id, max_segments_per_seq
    )
    return q_seqlen, kv_seqlen, q_offset, kv_offset


def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
    # convert the mask to seqlens, mask doesn't support ragged offsets
    if not attn_mask_type.is_padding():
        q_max_seqlen = segment_ids_q.shape[-1]
        kv_max_seqlen = segment_ids_kv.shape[-1]
        q_seq_lens = jnp.full_like(q_max_seqlen, q_max_seqlen, dtype=jnp.int32)
        kv_seq_lens = jnp.full_like(kv_max_seqlen, kv_max_seqlen, dtype=jnp.int32)
    else:
        q_seq_lens = jnp.sum(segment_ids_q, axis=-1).astype(jnp.int32)
        kv_seq_lens = jnp.sum(segment_ids_kv, axis=-1).astype(jnp.int32)
    return q_seq_lens, kv_seq_lens


@jax.tree_util.register_pytree_node_class
class SequenceDescriptor:
655
    """A class to describe the sequences with flexible initialization.
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
    - SequenceDescriptor.from_seqlens
      For non-THD (non-packed) cases, where each batch has only 1 sequence.
    - SequenceDescriptor.from_seqlens_and_offsets
      For THD (packed) cases, where each batch may have not only 1 sequence.
    - SequenceDescriptor.from_segment_ids_and_pos
      Experimental feature for THD (packed) cases with context parallelism.
    """

    seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    seq_offsets: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    segment_ids: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    segment_pos: Optional[Tuple[jnp.ndarray, jnp.ndarray]]

    def __init__(self, seqlens=None, seq_offsets=None, segment_ids=None, segment_pos=None):
        """
        Initialize to Tuple(jnp.zeros, jnp.zeros) because the primitive only accepts pure jax array
        """
        self.seqlens = (jnp.zeros(0), jnp.zeros(0)) if seqlens is None else seqlens
        self.seq_offsets = (jnp.zeros(0), jnp.zeros(0)) if seq_offsets is None else seq_offsets
        self.segment_ids = (jnp.zeros(0), jnp.zeros(0)) if segment_ids is None else segment_ids
        self.segment_pos = (jnp.zeros(0), jnp.zeros(0)) if segment_pos is None else segment_pos

    def tree_flatten(self):
        """
        Flatten method to register as a pytree node
        """
        return ((self.seqlens, self.seq_offsets, self.segment_ids, self.segment_pos), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Unflatten method to register as a pytree node
        """
        # aux_data is always None (see tree_flatten); only children matter.
        del aux_data
        seqlens, seq_offsets, segment_ids, segment_pos = children
        return cls(seqlens, seq_offsets, segment_ids, segment_pos)

    def get_seqlens_and_offsets(
        self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq
    ):
        """
        Acquire the seqlens/offsets for cuDNN backend.

        If this descriptor was built from segment ids/pos, they are converted to
        per-segment seqlens (and ragged offsets for THD layouts); otherwise the
        stored seqlens/seq_offsets are returned unchanged.
        """
        q_segment_ids, kv_segment_ids = self.segment_ids
        q_segment_pos, kv_segment_pos = self.segment_pos
        assert q_segment_ids.shape == q_segment_pos.shape
        assert kv_segment_ids.shape == kv_segment_pos.shape
        # No segment_ids/segment_pos provided (both are the empty placeholders),
        # so the pre-computed seqlens/offsets are authoritative.
        if q_segment_ids.size + kv_segment_ids.size == 0:
            return self.seqlens, self.seq_offsets

        if qkv_layout.is_thd():
            # THD (packed) path: both seqlens and ragged offsets are required.
            q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets(
                q_segment_ids,
                kv_segment_ids,
                q_segment_pos,
                kv_segment_pos,
                attn_mask_type,
                window_size,
                max_segments_per_seq,
            )
        else:
            # Non-THD path: only seqlens are needed; offsets stay empty.
            q_seqlens, kv_seqlens = _segment_ids_to_seqlens(
                q_segment_ids,
                kv_segment_ids,
                attn_mask_type,
            )
            q_offsets = kv_offsets = jnp.zeros(0)
        return (q_seqlens, kv_seqlens), (q_offsets, kv_offsets)

    @classmethod
    def _expand_to_pair(
        cls, value: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Internal helper to ensure a single value expands into a pair (q, kv).
        """
        if isinstance(value, jnp.ndarray):
            # A single array serves both the query and the key/value sides.
            return value, value

        if isinstance(value, tuple):
            if len(value) != 2:
                raise ValueError("Input tuple must have exactly 2 elements.")
            return value

        raise TypeError(
            "Expected a jax.numpy.ndarray or a tuple of two jax.numpy.ndarray, "
            f"but got {type(value).__name__}."
        )

    @classmethod
    def from_seqlens(
        cls,
        seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
    ) -> SequenceDescriptor:
        """
        Factory method for inputs with sequence lengths only (non-THD).
        Args:
            seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens):
                - q_seqlens (jnp.ndarray):
                  Sequence lengths for the query, with shape [batch].
                - kv_seqlen (jnp.ndarray):
                  Sequence lengths for the key and value, with shape [batch].
        Return:
            A SequenceDescriptor with only seqlens initialized.
        """
        return cls(seqlens=cls._expand_to_pair(seqlens))

    @classmethod
    def from_seqlens_and_offsets(
        cls,
        seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
        seq_offsets: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
    ) -> SequenceDescriptor:
        """
        Factory method for inputs with sequence lengths and offsets (THD).
        Args:
            seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens):
                - q_seqlens (jnp.ndarray):
                  Sequence lengths for the query, with shape [batch, max_seqlen].
                  Unused positions are padded with -1.
                - kv_seqlen (jnp.ndarray):
                  Sequence lengths for the key and value, with shape [batch, max_seqlen].
                  Unused positions are padded with -1.
            seq_offsets(Tuple(jnp.ndarray, jnp.ndarray)) = (q_offsets, kv_offsets)
                - q_seq_offsets (jnp.ndarray):
                  The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
                  Unused positions are padded with -1.
                - kv_seq_offsets (jnp.ndarray):
                  The offsets in the sequence dim for the key/value, with shape
                  [batch, max_seqlen + 1]. Unused positions are padded with -1.
        Return:
            A SequenceDescriptor with seqlens/seq_offsets initialized.
        """
        seqlen_pair = cls._expand_to_pair(seqlens)
        offset_pair = cls._expand_to_pair(seq_offsets)
        return cls(seqlens=seqlen_pair, seq_offsets=offset_pair)

    @classmethod
    def from_segment_ids_and_pos(
        cls,
        segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
        segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
    ) -> SequenceDescriptor:
        """
        Experimental factory method for inputs with segment IDs and optional positions. (THD)
        Args:
            segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
                - q_segment_ids (jnp.ndarray):
                  Query segment ids start with 1, with shape [batch, max_seqlen].
                  0s are treated as paddings.
                - kv_segment_ids (jnp.ndarray):
                  Key, value segment ids start with 1, with shape [batch, max_seqlen].
                  0s are treated as paddings.
            segment_pos(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_pos, kv_segment_pos)
                - q_segment_pos (jnp.ndarray):
                  The position inside each segment for query, with shape [batch, max_seqlen].
                - kv_segment_pos (jnp.ndarray):
                  The position inside each segment for key, value, with shape [batch, max_seqlen].
        Return:
            A SequenceDescriptor with segment_ids/segment_pos initialized.
        """
        q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)

        if segment_pos is None:
            # Default positions: 0..seqlen-1 per row, broadcast over the batch.
            def _default_pos(ids):
                return jnp.broadcast_to(jnp.arange(ids.shape[-1]), ids.shape)

            segment_pos = (_default_pos(q_seg_ids), _default_pos(kv_seg_ids))
        else:
            segment_pos = cls._expand_to_pair(segment_pos)

        return cls(
            segment_ids=(q_seg_ids, kv_seg_ids),
            segment_pos=segment_pos,
        )


def _legacy_fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    mask: Optional[jnp.ndarray],
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
    softmax_offset: Optional[jnp.ndarray] = None,
):
    """
    Perform non-THD (non-packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        mask (Optional[jnp.ndarray]):
            An optional mask tensor to mask out the attention scores, `True` means mask out.
            Intra-sequence padding is not valid. The padded tokens can only be on the
            right-most side, otherwise the results will be wrong.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        softmax_type (AttnSoftmaxType): Type of attention softmax.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
    """
    assert (
        not qkv_layout.is_thd()
    ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.BS3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BS2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BSHD_BSHD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        case _:
            raise ValueError(f"Unknown {qkv_layout=}")

    # convert the mask to seqlens, mask doesn't support ragged offsets
    if not attn_mask_type.is_padding():
        # No padding: every sequence spans the full max length.
        batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
        q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
        kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
    else:
        assert mask is not None
        # mask uses True to mask out; invert so True marks valid tokens before counting.
        mask = jnp.logical_not(mask)
        q_seq_lens = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_seq_lens = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]
        else:
            # When mask is causal, the actual seqlen is not the last row, use max to find it
            kv_seq_lens = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output = _fused_attn(
        qkv,
        bias,
        softmax_offset,
        SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        softmax_type=softmax_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=1,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output
944
945


946
947
948
949
950
951
952
953
def fused_attn_thd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: jnp.ndarray,
    kv_seq_offsets: jnp.ndarray,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
    softmax_offset: Optional[jnp.ndarray] = None,
):
    """
    Deprecated THD fused attn, please use fused_attn with SequenceDescriptor
    """
    warnings.warn(
        "fused_attn_thd is deprecated, please use fused_attn with SequenceDescriptor",
        DeprecationWarning,
    )

    assert (
        qkv_layout.is_thd()
    ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_T2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_THD_THD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        case _:
            raise ValueError(f"Unknown {qkv_layout=}")

    # Seqlens are per-segment tables padded with -1; offsets carry one extra
    # column for the end offset of the last segment (see fused_attn docstring).
    batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
    assert q_seq_lens.shape == (batch, q_max_seqlen)
    assert kv_seq_lens.shape == (batch, kv_max_seqlen)
    assert q_seq_offsets.shape == (batch, q_max_seqlen + 1)
    assert kv_seq_offsets.shape == (batch, kv_max_seqlen + 1)

    output = _fused_attn(
        qkv,
        bias,
        softmax_offset,
        SequenceDescriptor.from_seqlens_and_offsets(
            (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
        ),
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output


1025
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    softmax_offset: Optional[jnp.ndarray],
    sequence_descriptor: SequenceDescriptor,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int,
    window_size: Optional[Tuple[int, int]],
    context_parallel_strategy: CPStrategy,
    context_parallel_causal_load_balanced: bool,
    context_parallel_axis: str,
    context_checkpoint_name: str = "context",
    stripe_size: int | None = None,
):
    """
    Differentiable core of fused attention, wrapped in a jax.custom_vjp.

    Arguments at positions 5..18 (attn_bias_type onward) are declared
    non-differentiable via nondiff_argnums; only qkv, bias, softmax_offset,
    sequence_descriptor, and seed participate in autodiff.
    """
    # The primal path reuses the forward rule and drops the residuals;
    # autodiff goes through _fused_attn_fwd_rule/_fused_attn_bwd_rule instead.
    output, _ = _fused_attn_fwd_rule(
        qkv,
        bias,
        softmax_offset,
        sequence_descriptor,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        softmax_type,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        window_size,
        context_parallel_strategy,
        context_parallel_causal_load_balanced,
        context_parallel_axis,
        context_checkpoint_name=context_checkpoint_name,
        stripe_size=stripe_size,
    )
    return output


1071
def _fused_attn_fwd_rule(
    qkv,
    bias,
    softmax_offset,
    sequence_descriptor,
    seed,
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    softmax_type,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_strategy,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    context_checkpoint_name,
    stripe_size,
):
    """
    Forward rule of the _fused_attn custom VJP.

    Runs the fused-attention forward primitive and returns the attention
    output together with the residuals needed by _fused_attn_bwd_rule.
    """
    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        qkv,
        bias,
        softmax_offset,
        sequence_descriptor,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        softmax_type=softmax_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
        stripe_size=stripe_size,
    )
    # Tag the residuals with a checkpoint name so rematerialization policies
    # (jax.checkpoint / checkpoint_name) can target them explicitly.
    output = checkpoint_name(output, context_checkpoint_name)
    softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
    rng_state = checkpoint_name(rng_state, context_checkpoint_name)
    return output, (
        qkv,
        bias,
        sequence_descriptor,
        softmax_aux,
        rng_state,
        softmax_offset,
        output,
    )


def _fused_attn_bwd_rule(
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    softmax_type,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_strategy,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    context_checkpoint_name,
    stripe_size,
    ctx,
    dz,
):
    """
    Backward rule of the _fused_attn custom VJP.

    Consumes the residuals saved by the forward rule (ctx) and the incoming
    cotangent dz, and produces gradients for the differentiable inputs
    (qkv, bias, softmax_offset); sequence_descriptor and seed get None.
    """
    # Only used by the forward rule's checkpoint tagging.
    del context_checkpoint_name
    (
        qkv,
        bias,
        sequence_descriptor,
        softmax_aux,
        rng_state,
        softmax_offset,
        output,
    ) = ctx
    grad_qkv, grad_bias, grad_softmax_offset = tex.fused_attn_bwd(
        qkv,
        bias,
        softmax_offset,
        softmax_aux,
        rng_state,
        output,
        dz,
        sequence_descriptor,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        softmax_type=softmax_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
        stripe_size=stripe_size,
    )
    # No bias was applied, so there is no bias gradient to report.
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
    # softmax_offset is only trainable for the learnable-softmax variant.
    if softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX:
        grad_softmax_offset = None
    # One entry per differentiable argument of _fused_attn:
    # (qkv, bias, softmax_offset, sequence_descriptor, seed).
    return (
        grad_qkv,
        grad_bias,
        grad_softmax_offset,
        None,
        None,
    )
1188
1189
1190


# Register the forward/backward rules that implement _fused_attn's custom VJP.
_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200


def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    sequence_descriptor: SequenceDescriptor,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
    context_checkpoint_name: str = "context",
    softmax_offset: Optional[jnp.ndarray] = None,
    stripe_size: int | None = None,
):
    """
    Perform cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        softmax_type (AttnSoftmaxType): Type of attention softmax.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]):
            Sliding window size.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
        context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
        softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
            [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
            If provided, this parameter will receive gradients during backpropagation.
        stripe_size (int |  None):
            Indicates the striping size to be used when using ReorderStrategy.Striped.
            Currently, a stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas a stripe_size=1
            is supported for both, CP + THD + Striped + AG and CP + THD + Striped + P2P(Ring)
            None indicates no striping strategy
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.

    Examples (non-THD, also known as non-packed):
        >>> #  q_segment_ids = [[1, 1, 1, 0], [1, 1, 0, 0]], 0 means padded tokens
        >>> # kv_segment_ids = [[1, 0, 0, 0], [1, 1, 0, 0]], 0 means padded tokens
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> q_seq_lens = jnp.asarray([3, 2])
        >>> kv_seq_lens = jnp.asarray([1, 2])
        >>> sequence_desc = SequenceDescriptor.from_seqlens(
                seqlens=(q_seq_lens, kv_seq_lens))
        >>> out = fused_attn((qkv,), None, sequence_desc, None,
                             AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
                             QKVLayout.BS3HD, 0.125, 0, True, 3)

    Examples (THD, also known as packed):
        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
        >>> # segment_pos = [[0, 1, 0, 0], [0, 1, 0, 1]]
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> # 3 segments in first seq, 2 segments in second seq
        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
        >>> # seq_offsets need to include the end offset of the last segments
        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
        >>> sequence_desc = SequenceDescriptor.from_seqlens_and_offsets(
                seqlens=(q_seq_lens, kv_seq_lens),
                seq_offsets=(q_seq_offsets, kv_seq_offsets))
        >>> out = fused_attn((qkv,), None, sequence_desc, None,
                             AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
                             QKVLayout.T3HD, 0.125, 0, True, 3)
    """
    # Backward-compat path: a raw boolean mask (or None) in place of a
    # SequenceDescriptor routes through the deprecated legacy implementation.
    if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray):
        warnings.warn(
            "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. "
            + "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.",
            DeprecationWarning,
        )
        if max_segments_per_seq != 1:
            raise ValueError("Passing mask is only supported for non-THD case.")
        return _legacy_fused_attn(
            qkv,
            bias,
            sequence_descriptor,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            softmax_type=softmax_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training,
            window_size=window_size,
            context_parallel_strategy=context_parallel_strategy,
            context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
            context_parallel_axis=context_parallel_axis,
            softmax_offset=softmax_offset,
        )
    output = _fused_attn(
        qkv,
        bias,
        softmax_offset,
        sequence_descriptor,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        softmax_type=softmax_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
        context_checkpoint_name=context_checkpoint_name,
        stripe_size=stripe_size,
    )
    return output