# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
from typing import Optional, Tuple
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp

from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_QKV_Format
from transformer_engine.transformer_engine_jax import nvte_get_qkv_format

from . import cpp_extensions as tex


class AttnBiasType(Enum):
    """
    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """

    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """
    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    """

    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
    CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK


class QKVLayout(Enum):
    """
    BSHD Format:
        - BS3HD: q,k,v are interleave packed as a tensor with shape [b, s, 3, h, d].
        - BSHD_BS2HD: q with shape [b, s, h, d] and kv are interleaved with shape [b, s, 2, h, d].
        - BSHD_BSHD_BSHD: q,k,v are seperate tensors with shape [b, s, h, d]
    THD Format: Shape is same as BSHD layout but allow multiple segments packed in a sequence.
        - T3HD: q,k,v are interleave packed as a tensor with shape [b, s, 3, h, d].
        - THD_T2HD: q with shape [b, s, h, d] and kv are interleaved with shape [b, s, 2, h, d].
        - THD_THD_THD: q,k,v are seperate tensors with shape [b, s, h, d]
    """

    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
    T3HD = NVTE_QKV_Layout.NVTE_T3HD
    THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
    THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD


class QKVFormat(Enum):
    """
    SBHD: q,k,v memory layout with [s, b, ..., h, d]
    BSHD: q,k,v memory layout with [b, s, ..., h, d]
    THD: q,k,v memory layout is the same as BSHD, but allows multiple segments packed in a sequence.
    """

    SBHD = NVTE_QKV_Format.NVTE_SBHD
    BSHD = NVTE_QKV_Format.NVTE_BSHD
    THD = NVTE_QKV_Format.NVTE_THD


def get_qkv_format(qkv_layout):
    """
    Get qkv_format from qkv_layout
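
    Example:
        >>> get_qkv_format(QKVLayout.BSHD_BS2HD) == QKVFormat.BSHD
        True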
    """
    return QKVFormat(nvte_get_qkv_format(qkv_layout.value))


def make_swa_mask(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    window_size: Optional[Tuple[int, int]] = None,
    attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
    dtype: jax.typing.DTypeLike = jnp.float32,
):
    """
    Generate sliding window mask. `True` or `1` means keep the element.

    For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type,
    the sliding window diagonal is aligned to the bottom right corner, and for other
    mask types, the top left corner.

    Parameters
    ----------
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    window_size: Optional[Tuple[int, int]] = None
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Negative number in window size means infinity window.
        `None` means no sliding window.
    attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK
    dtype: jax.typing.DTypeLike, default=jnp.float32
        The mask data type.
    Returns
    ----------
    swa_mask: jax.numpy.ndarray
        Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions
        that will get attention, value 0 are the masked out positions.
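
    Examples
    ----------
        >>> # Illustrative only: a 4x6 mask with window (2, 0) aligned to the bottom-right corner,
        >>> # so each query attends to its (shifted) diagonal position and the two keys before it.
        >>> mask = make_swa_mask(4, 6, window_size=(2, 0),
        ...                      attn_mask_type=AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
        >>> mask.shape
        (4, 6)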
    """
    swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
    if window_size is None:
        return swa_mask
    bottom_right_masks = [
        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
    ]
    left_window, right_window = window_size
    if attn_mask_type in bottom_right_masks:
        if left_window < 0:
            left_window = max_seqlen_kv
        if right_window < 0:
            right_window = max_seqlen_kv
        bottom_right_shift = max_seqlen_kv - max_seqlen_q
        swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift)
        swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift)
    else:
        if left_window < 0:
            left_window = max_seqlen_q
        if right_window < 0:
            right_window = max_seqlen_q
        swa_mask = jnp.triu(swa_mask, k=-left_window)
        swa_mask = jnp.tril(swa_mask, k=right_window)
    return swa_mask


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert a string attn_mask_type to AttnMaskType.
    TE-JAX currently falls back to the padding-version kernels for library integration.
    The overhead between the padding and non-padding versions should be small.
    However, we will lift this limitation in the near future.
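
    Example:
        >>> canonicalize_attn_mask_type("causal_padding") == AttnMaskType.PADDING_CAUSAL_MASK
        True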
    """
    match attn_mask_type:
        case "no_mask":
            return AttnMaskType.NO_MASK
        case "padding":
            return AttnMaskType.PADDING_MASK
        case "causal":
            return AttnMaskType.CAUSAL_MASK
        case "causal_bottom_right" | "bottom_right_causal":
            return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
        case "padding_causal" | "causal_padding":
            return AttnMaskType.PADDING_CAUSAL_MASK
        case (
            "padding_causal_bottom_right"
            | "causal_padding_bottom_right"
            | "bottom_right_causal_padding"
            | "bottom_right_padding_causal"
        ):
            return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
    raise ValueError(
        f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal',"
        " 'padding_causal', 'causal_padding', 'causal_bottom_right',"
        " 'padding_causal_bottom_right'}"
    )


def is_fused_attn_kernel_available(
    q_dtype,
    kv_dtype,
    qkv_layout,
    attn_bias_type,
    attn_mask_type,
    dropout_probability,
    q_num_heads,
    kv_num_heads,
    q_max_seqlen,
    kv_max_seqlen,
    head_dim,
    window_size: Optional[Tuple[int, int]] = None,
    is_context_parallel: bool = False,
):
    """
    Check whether the fused attention kernel is supported for the given configuration.
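
    Example (illustrative only; the argument values below are hypothetical):
        >>> ok = is_fused_attn_kernel_available(
        ...     jnp.bfloat16, jnp.bfloat16, QKVLayout.BS3HD,
        ...     AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
        ...     0.0, 16, 16, 2048, 2048, 64)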
    """

    def make_helper(attn_mask_type):
        return tex.FusedAttnHelper(
            q_dtype,
            kv_dtype,
            qkv_layout.value,
            attn_bias_type.value,
            attn_mask_type.value,
            dropout_probability,
            q_num_heads,
            kv_num_heads,
            q_max_seqlen,
            kv_max_seqlen,
            head_dim,
            (-1, -1) if window_size is None else window_size,
        )

    if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
        return False

    # For context parallelism we need to check additional mask types
    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
            return False

    return True


def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = q_max_seqlen
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = qkv[1].shape[1]
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = qkv[1].shape[1]
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    return batch, q_max_seqlen, kv_max_seqlen


def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Reorders a tensor for load balancing the compute of causal attention."""
    seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)


def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Inverse operation of `reorder_causal_load_balancing`."""
    seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)


def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    mask: Optional[jnp.ndarray],
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    Perform non-THD (non-packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        mask (Optional[jnp.ndarray]):
            An optional mask tensor to mask out the attention scores; `True` means mask out.
            Intra-sequence padding is not supported: padded tokens may only appear at the end
            (right side) of each sequence, otherwise the results will be wrong.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
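
    Examples:
        >>> # Illustrative only: bf16 self-attention with an interleaved (packed) QKV tensor,
        >>> # a causal mask (no explicit mask tensor needed), no bias and no dropout.
        >>> b, s, h, d = 2, 512, 12, 64
        >>> qkv_packed = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> out = fused_attn((qkv_packed,), None, None, None,
        ...                  AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
        ...                  QKVLayout.BS3HD, scaling_factor=0.125,
        ...                  dropout_probability=0.0, is_training=True)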
    """
    assert (
        get_qkv_format(qkv_layout) != QKVFormat.THD
    ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.BS3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BS2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BSHD_BSHD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
    # convert the mask to seqlens, mask doesn't support ragged offsets
    if attn_mask_type in [
        AttnMaskType.NO_MASK,
        AttnMaskType.CAUSAL_MASK,
        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
    ]:
        batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
        q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
        kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
    else:
        assert mask is not None
        # In the provided mask, `True` means "mask out"; flip it so that summing counts the kept
        # positions, which gives the per-sequence valid lengths.
        mask = jnp.logical_not(mask)
        q_seq_lens = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_seq_lens = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]
        else:
            # When the mask is causal, the actual seqlen is not in the last row; use max to find it
            kv_seq_lens = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output = _fused_attn(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        None,
        None,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=1,
        window_size=window_size,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output


def fused_attn_thd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: jnp.ndarray,
    kv_seq_offsets: jnp.ndarray,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    (Experimental) Perform THD (packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        q_seq_lens (jnp.ndarray):
            Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are
            padded with -1.
        kv_seq_lens (jnp.ndarray):
            Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused positions
            are padded with -1.
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
            Unused positions are padded with -1.
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the key and value, with shape [batch, max_seqlen + 1].
            Unused positions are padded with -1.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int):
            The maximum number of segments inside a sequence. This parameter bounds resource
            usage and must be static during end-to-end training. XLA compile time and memory
            consumption are proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]):
            Sliding window size.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.

    Examples:
        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> # 3 segments in first seq, 2 segments in second seq
        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
        >>> # seq_offsets need to include the end offset of the last segments
        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
        >>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
                                 q_seq_offsets, kv_seq_offsets, None,
                                 AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
                                 QKVLayout.T3HD, 0.125, 0, True, 3)
    """
    assert (
        get_qkv_format(qkv_layout) == QKVFormat.THD
    ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_T2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_THD_THD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

    batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
    assert q_seq_lens.shape == (batch, q_max_seqlen)
    assert kv_seq_lens.shape == (batch, kv_max_seqlen)
    assert q_seq_offsets.shape == (batch, q_max_seqlen + 1)
    assert kv_seq_offsets.shape == (batch, kv_max_seqlen + 1)

    output = _fused_attn(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output


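# `_fused_attn` wraps the forward/backward cuDNN kernels in a custom VJP. Arguments 7-16
# (attn_bias_type through context_parallel_axis) are configuration rather than tensors, so they
# are marked non-differentiable and are passed as leading arguments to the backward rule.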
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
def _fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: Optional[jnp.ndarray],
    kv_seq_offsets: Optional[jnp.ndarray],
    seed: jnp.ndarray,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int,
    window_size: Optional[Tuple[int, int]],
    context_parallel_causal_load_balanced: bool,
    context_parallel_axis: str,
):
    output, _ = _fused_attn_fwd_rule(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        window_size,
        context_parallel_causal_load_balanced,
        context_parallel_axis,
    )
    return output


def _fused_attn_fwd_rule(
    qkv,
    bias,
    q_seq_lens,
    kv_seq_lens,
    q_seq_offsets,
    kv_seq_offsets,
    seed,
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
):
    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        qkv_layout=qkv_layout.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    output = checkpoint_name(output, "context")
    softmax_aux = checkpoint_name(softmax_aux, "context")
    rng_state = checkpoint_name(rng_state, "context")
    return output, (
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        softmax_aux,
        rng_state,
        output,
    )


def _fused_attn_bwd_rule(
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    ctx,
    dz,
):
    (
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        softmax_aux,
        rng_state,
        output,
    ) = ctx
    grad_qkv, grad_bias = tex.fused_attn_bwd(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        qkv_layout=qkv_layout.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
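    # One gradient per differentiable positional argument: qkv and bias get real gradients, while
    # the seq-len/offset tensors and the dropout seed get None.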
    return grad_qkv, grad_bias, None, None, None, None, None


_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)