# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_QKV_Format
from transformer_engine.transformer_engine_jax import nvte_get_qkv_format

from . import cpp_extensions as tex


class AttnBiasType(Enum):
    """
    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """

    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """
    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    CAUSAL_BOTTOM_RIGHT_MASK: A causal mask aligned to the bottom-right corner of the softmax inputs.
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK: A combination of both bottom-right causal and padding masks.
    """

    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
    CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK

    def is_causal(self):
        """Returns True if the mask is a causal mask"""
        return self in [
            AttnMaskType.CAUSAL_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        ]

    def is_padding(self):
        """Returns True if the mask includes padding"""
        return self in [
            AttnMaskType.PADDING_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        ]

    def is_bottom_right(self):
        """Returns True if the causal mask is calculated from the bottom-right section"""
        return self in [
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        ]
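
# Predicate sketch (illustrative): PADDING_CAUSAL_MASK is both causal and
# padded, but is not a bottom-right diagonal variant.
# >>> AttnMaskType.PADDING_CAUSAL_MASK.is_causal()
# True
# >>> AttnMaskType.PADDING_CAUSAL_MASK.is_padding()
# True
# >>> AttnMaskType.PADDING_CAUSAL_MASK.is_bottom_right()
# False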


class QKVFormat(Enum):
    """
    SBHD: q,k,v memory layout with [s, b, ..., h, d]
    BSHD: q,k,v memory layout with [b, s, ..., h, d]
    THD: q,k,v memory layout is the same as BSHD, but allows multiple segments packed in a sequence.
    """

    SBHD = NVTE_QKV_Format.NVTE_SBHD
    BSHD = NVTE_QKV_Format.NVTE_BSHD
    THD = NVTE_QKV_Format.NVTE_THD


class QKVLayout(Enum):
    """
    BSHD Format:
        - BS3HD: q,k,v are interleave packed as a tensor with shape [b, s, 3, h, d].
        - BSHD_BS2HD: q with shape [b, s, h, d] and kv are interleaved with shape [b, s, 2, h, d].
        - BSHD_BSHD_BSHD: q,k,v are seperate tensors with shape [b, s, h, d]
    THD Format: Shape is same as BSHD layout but allow multiple segments packed in a sequence.
        - T3HD: q,k,v are interleave packed as a tensor with shape [b, s, 3, h, d].
        - THD_T2HD: q with shape [b, s, h, d] and kv are interleaved with shape [b, s, 2, h, d].
        - THD_THD_THD: q,k,v are seperate tensors with shape [b, s, h, d]
    """

    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
    T3HD = NVTE_QKV_Layout.NVTE_T3HD
    THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
    THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD

    def get_qkv_format(self):
        """
        Return the corresponding qkv_format (BSHD, SBHD, THD)
        """
        return QKVFormat(nvte_get_qkv_format(self.value))

    def is_qkvpacked(self):
        """
        Return True if query, key, and value are packed into a single tensor
        """
        return self in [QKVLayout.BS3HD, QKVLayout.T3HD]

    def is_kvpacked(self):
        """
        Return True if key and value are packed into a single tensor
        """
        return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD]

    def is_separate(self):
        """
        Return True if the query, key, value are three separate tensors
        """
        return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD]

    def is_thd(self):
        """
        Return True if the layout belongs to the THD format
        """
        return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
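
# Layout helper sketch (illustrative):
# >>> QKVLayout.BSHD_BS2HD.get_qkv_format() == QKVFormat.BSHD
# True
# >>> QKVLayout.T3HD.is_qkvpacked(), QKVLayout.T3HD.is_thd()
# (True, True)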


class CPStrategy(Enum):
    """Defines the context parallel strategies of Jax fused attention.

    DEFAULT: Automatically choose a strategy when the context parallel axis is sharded.
    ALL_GATHER: All-gather/reduce scatter implementation.
    RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
    """

    DEFAULT = 0
    ALL_GATHER = 1
    RING = 2
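
# Illustrative: ring attention can be requested explicitly via
# fused_attn(..., context_parallel_strategy=CPStrategy.RING,
#            context_parallel_axis="cp")  # "cp" is a hypothetical mesh axis name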


def make_swa_mask(
    segment_pos_q: jnp.ndarray,
    segment_pos_kv: jnp.ndarray,
    window_size: Optional[Tuple[int, int]] = None,
    dtype: jax.typing.DTypeLike = jnp.float32,
):
    """
    Generate a sliding window mask (1 = attend, 0 = masked).

    Args:
        segment_pos_q (jnp.ndarray):
            Query positions within each segment. For example, a batch with segment_ids =
            [[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos =
            [[0, 1, 2, 0, 1, 2, 3, 4]].
        segment_pos_kv (jnp.ndarray):
            Key/value positions within each segment.
        window_size (Optional[Tuple[int, int]], optional):
            Sliding window size for local attention, where query at position i attends to keys
            in [i - window_size[0], i + window_size[1]] inclusive. A negative number means an
            infinite window; None means no sliding window.
            Defaults to None.
        dtype (jax.typing.DTypeLike, optional):
            Mask data type. Defaults to jnp.float32.

    Returns:
        jnp.ndarray:
            The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv].
    """
    if window_size is not None:
        left_window, right_window = window_size
    else:
        left_window = right_window = jnp.inf
    left_window = jnp.inf if left_window < 0 else left_window
    right_window = jnp.inf if right_window < 0 else right_window
    pos_q = jnp.expand_dims(segment_pos_q, axis=-1)
    pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2)
    inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window)
    inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3)
    return inv_swa_mask.astype(dtype)
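
# Usage sketch (illustrative): a (1, 0) window on a single segment of length 4
# yields a banded lower-triangular mask of shape [b, 1, max_seqlen_q, max_seqlen_kv].
# >>> pos = jnp.arange(4)[None, :]
# >>> make_swa_mask(pos, pos, window_size=(1, 0)).shape
# (1, 1, 4, 4)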


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert string attn_mask_type to AttnMaskType
    TE-JAX currently fall back to the padding version kernels for the libraries integration.
    The overhead between padding and non-padding version should be small.
    However, we will lease this limitation in the near feature.
    """
    match attn_mask_type:
        case "no_mask":
            return AttnMaskType.NO_MASK
        case "padding":
            return AttnMaskType.PADDING_MASK
        case "causal":
            return AttnMaskType.CAUSAL_MASK
        case "causal_bottom_right" | "bottom_right_causal":
            return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
        case "padding_causal" | "causal_padding":
            return AttnMaskType.PADDING_CAUSAL_MASK
        case (
            "padding_causal_bottom_right"
            | "causal_padding_bottom_right"
            | "bottom_right_causal_padding"
            | "bottom_right_padding_causal"
        ):
            return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
    raise ValueError(
        f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal',"
        " 'padding_causal', 'causal_padding', 'causal_bottom_right',"
        " 'padding_causal_bottom_right'}"
    )
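
# Illustrative: equivalent string aliases normalize to one enum member.
# >>> canonicalize_attn_mask_type("causal_padding") == AttnMaskType.PADDING_CAUSAL_MASK
# True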


def is_fused_attn_kernel_available(
    q_dtype,
    kv_dtype,
    qkv_layout,
    attn_bias_type,
    attn_mask_type,
    dropout_probability,
    q_num_heads,
    kv_num_heads,
    q_max_seqlen,
    kv_max_seqlen,
    head_dim,
    window_size: Optional[Tuple[int, int]] = None,
):
    """
    Check whether a fused attention kernel is available for the given configuration.
    """

    def make_helper(attn_mask_type):
        return tex.FusedAttnHelper(
            q_dtype,
            kv_dtype,
            qkv_layout.value,
            attn_bias_type.value,
            attn_mask_type.value,
            dropout_probability,
            q_num_heads,
            kv_num_heads,
            q_max_seqlen,
            kv_max_seqlen,
            head_dim,
            (-1, -1) if window_size is None else window_size,
        )

    if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
        return False

    return True
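
# Usage sketch (illustrative shapes; actual availability depends on the GPU
# architecture, cuDNN version, and TE build):
# >>> is_fused_attn_kernel_available(
# ...     jnp.bfloat16, jnp.bfloat16, QKVLayout.BS3HD, AttnBiasType.NO_BIAS,
# ...     AttnMaskType.CAUSAL_MASK, dropout_probability=0.0, q_num_heads=16,
# ...     kv_num_heads=16, q_max_seqlen=512, kv_max_seqlen=512, head_dim=64)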


def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = q_max_seqlen
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = qkv[1].shape[1]
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = qkv[1].shape[1]
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    return batch, q_max_seqlen, kv_max_seqlen


def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Reorders a tensor for load balancing the compute of causal attention."""
    seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)


def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Inverse operation of `reorder_causal_load_balancing`."""
    seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
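
# Illustrative round trip (assumes the sequence dim is divisible by 2 * cp_size):
# >>> x = jnp.arange(2 * 8 * 4 * 16, dtype=jnp.float32).reshape(2, 8, 4, 16)
# >>> y = reorder_causal_load_balancing(x, cp_size=2, tensor_format=QKVFormat.BSHD)
# >>> bool((inverse_reorder_causal_load_balancing(y, 2, QKVFormat.BSHD) == x).all())
# True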


def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    mask: Optional[jnp.ndarray],
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    Perform non-THD (non-packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        mask (Optional[jnp.ndarray]):
            An optional mask tensor to mask out the attention scores; `True` means mask out.
            Intra-sequence padding is not valid: padded tokens may only appear at the
            right-most end of each sequence, otherwise the results will be wrong.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
        context_parallel_strategy (CPStrategy):
            The context parallelism strategy to use. Defaults to CPStrategy.DEFAULT.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running
            context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
    """
    assert (
        not qkv_layout.is_thd()
    ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.BS3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BS2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BSHD_BSHD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

    # Convert the mask to seqlens; the mask doesn't support ragged offsets
    if not attn_mask_type.is_padding():
        batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
        q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
        kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_seq_lens = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_seq_lens = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]
        else:
            # When the mask is causal, the actual seqlen is not in the last row; use max to find it
            kv_seq_lens = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output = _fused_attn(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        None,
        None,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=1,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output
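
# Usage sketch (illustrative only; requires a cuDNN-capable GPU build of TE).
# A pure causal mask needs no explicit mask tensor:
# >>> b, s, h, d = 2, 512, 12, 64
# >>> qkv_packed = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
# >>> out = fused_attn((qkv_packed,), None, None, None,
# ...                  AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
# ...                  QKVLayout.BS3HD, scaling_factor=d**-0.5,
# ...                  dropout_probability=0.0, is_training=True)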


def fused_attn_thd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: jnp.ndarray,
    kv_seq_offsets: jnp.ndarray,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    (Experimental) Perform THD (packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        q_seq_lens (jnp.ndarray):
            Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are
            padded with -1.
        kv_seq_lens (jnp.ndarray):
            Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused
            positions are padded with -1.
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
            Unused positions are padded with -1.
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the key and value, with shape
            [batch, max_seqlen + 1]. Unused positions are padded with -1.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int):
            The maximum number of segments inside a sequence. This parameter constrains
            resource usage and needs to be static during end-to-end training. XLA compile
            time and memory consumption are proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]):
            Sliding window size.
        context_parallel_strategy (CPStrategy):
            The context parallelism strategy to use. Defaults to CPStrategy.DEFAULT.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running
            context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.

    Examples:
        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> # 3 segments in first seq, 2 segments in second seq
        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
        >>> # seq_offsets need to include the end offset of the last segments
        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
        >>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
        ...                      q_seq_offsets, kv_seq_offsets, None,
        ...                      AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
        ...                      QKVLayout.T3HD, 0.125, 0, True, 3)
    """
    assert (
        qkv_layout.is_thd()
    ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_T2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_THD_THD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

    batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
    assert q_seq_lens.shape == (batch, q_max_seqlen)
    assert kv_seq_lens.shape == (batch, kv_max_seqlen)
    assert q_seq_offsets.shape == (batch, q_max_seqlen + 1)
    assert kv_seq_offsets.shape == (batch, kv_max_seqlen + 1)

    output = _fused_attn(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output


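# The custom VJP treats all configuration arguments (positions 7-17, i.e.
# attn_bias_type through context_parallel_axis) as non-differentiable.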
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: Optional[jnp.ndarray],
    kv_seq_offsets: Optional[jnp.ndarray],
    seed: jnp.ndarray,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int,
    window_size: Optional[Tuple[int, int]],
    context_parallel_strategy: CPStrategy,
    context_parallel_causal_load_balanced: bool,
    context_parallel_axis: str,
):
    output, _ = _fused_attn_fwd_rule(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        window_size,
        context_parallel_strategy,
        context_parallel_causal_load_balanced,
        context_parallel_axis,
    )
    return output


def _fused_attn_fwd_rule(
    qkv,
    bias,
    q_seq_lens,
    kv_seq_lens,
    q_seq_offsets,
    kv_seq_offsets,
    seed,
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_strategy,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
):
    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        qkv_layout=qkv_layout.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    output = checkpoint_name(output, "context")
    softmax_aux = checkpoint_name(softmax_aux, "context")
    rng_state = checkpoint_name(rng_state, "context")
    return output, (
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        softmax_aux,
        rng_state,
        output,
    )


def _fused_attn_bwd_rule(
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_strategy,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    ctx,
    dz,
):
    (
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        softmax_aux,
        rng_state,
        output,
    ) = ctx
    grad_qkv, grad_bias = tex.fused_attn_bwd(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        qkv_layout=qkv_layout.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
    return grad_qkv, grad_bias, None, None, None, None, None


_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)