# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_QKV_Format
from transformer_engine.transformer_engine_jax import nvte_get_qkv_format

from . import cpp_extensions as tex


class AttnBiasType(Enum):
    """
    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """

    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """
    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of padding at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    CAUSAL_BOTTOM_RIGHT_MASK: A causal mask whose diagonal is aligned to the bottom right
        corner of the softmax inputs.
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK: A combination of both bottom-right causal and padding
        masks.
    """

    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
    CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK


class QKVLayout(Enum):
    """
    BSHD Format:
        - BS3HD: q,k,v are interleaved and packed in a tensor with shape [b, s, 3, h, d].
        - BSHD_BS2HD: q with shape [b, s, h, d] and k,v interleaved with shape [b, s, 2, h, d].
        - BSHD_BSHD_BSHD: q,k,v are separate tensors, each with shape [b, s, h, d].
    THD Format: Shapes are the same as in the BSHD layout, but multiple segments may be
    packed within a sequence.
        - T3HD: q,k,v are interleaved and packed in a tensor with shape [b, s, 3, h, d].
        - THD_T2HD: q with shape [b, s, h, d] and k,v interleaved with shape [b, s, 2, h, d].
        - THD_THD_THD: q,k,v are separate tensors, each with shape [b, s, h, d].
    """

    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
    T3HD = NVTE_QKV_Layout.NVTE_T3HD
    THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
    THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD


class QKVFormat(Enum):
    """
    SBHD: q,k,v memory layout with [s, b, ..., h, d]
    BSHD: q,k,v memory layout with [b, s, ..., h, d]
    THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence.
    """

    SBHD = NVTE_QKV_Format.NVTE_SBHD
    BSHD = NVTE_QKV_Format.NVTE_BSHD
    THD = NVTE_QKV_Format.NVTE_THD


class CPStrategy(Enum):
    """Defines the context parallel strategies of Jax fused attention.

    DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
    ALL_GATHER: All-gather/reduce scatter implementation.
    RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
    """

    DEFAULT = 0
    ALL_GATHER = 1
    RING = 2


def get_qkv_format(qkv_layout):
    """
    Get qkv_format from qkv_layout
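
    Example (illustrative):
        >>> get_qkv_format(QKVLayout.THD_THD_THD) == QKVFormat.THD
        True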
    """
    return QKVFormat(nvte_get_qkv_format(qkv_layout.value))


def make_swa_mask(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    window_size: Optional[Tuple[int, int]] = None,
    attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
    dtype: jax.typing.DTypeLike = jnp.float32,
):
    """
    Generate sliding window mask. `True` or `1` means keep the element.

    For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type,
    the sliding window diagonal is aligned to the bottom right corner, and for other
    mask types, the top left corner.

    Parameters
    ----------
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    window_size: Optional[Tuple[int, int]] = None
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. A negative number in the window size means an infinite
        window. `None` means no sliding window.
    attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK
    dtype: jax.typing.DTypeLike, default=jnp.float32
        The mask data type.
    Returns
    ----------
    swa_mask: jax.numpy.tensor
        Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are positions
        that will get attention; elements with value 0 are masked out.
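
    Examples
    ----------
    Illustrative; the exact Array repr may vary across JAX versions:

    >>> # window_size=(1, 1) keeps a band of width 3 around the top-left diagonal
    >>> make_swa_mask(3, 3, window_size=(1, 1))
    Array([[1., 1., 0.],
           [1., 1., 1.],
           [0., 1., 1.]], dtype=float32)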
    """
    swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
    if window_size is None:
        return swa_mask
    bottom_right_masks = [
        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
    ]
    left_window, right_window = window_size
    if attn_mask_type in bottom_right_masks:
        if left_window < 0:
            left_window = max_seqlen_kv
        if right_window < 0:
            right_window = max_seqlen_kv
        bottom_right_shift = max_seqlen_kv - max_seqlen_q
        swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift)
        swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift)
    else:
        if left_window < 0:
            left_window = max_seqlen_q
        if right_window < 0:
            right_window = max_seqlen_q
        swa_mask = jnp.triu(swa_mask, k=-left_window)
        swa_mask = jnp.tril(swa_mask, k=right_window)
    return swa_mask


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert string attn_mask_type to AttnMaskType
    TE-JAX currently fall back to the padding version kernels for the libraries integration.
    The overhead between padding and non-padding version should be small.
    However, we will lease this limitation in the near feature.
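
    Example (illustrative):
        >>> canonicalize_attn_mask_type("causal_padding") is AttnMaskType.PADDING_CAUSAL_MASK
        True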
    """
    match attn_mask_type:
        case "no_mask":
            return AttnMaskType.NO_MASK
        case "padding":
            return AttnMaskType.PADDING_MASK
        case "causal":
            return AttnMaskType.CAUSAL_MASK
        case "causal_bottom_right" | "bottom_right_causal":
            return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
        case "padding_causal" | "causal_padding":
            return AttnMaskType.PADDING_CAUSAL_MASK
        case (
            "padding_causal_bottom_right"
            | "causal_padding_bottom_right"
            | "bottom_right_causal_padding"
            | "bottom_right_padding_causal"
        ):
            return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
    raise ValueError(
        f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal',"
        " 'padding_causal', 'causal_padding', 'causal_bottom_right',"
        " 'padding_causal_bottom_right'}"
    )


def is_fused_attn_kernel_available(
    q_dtype,
    kv_dtype,
    qkv_layout,
    attn_bias_type,
    attn_mask_type,
    dropout_probability,
    q_num_heads,
    kv_num_heads,
    q_max_seqlen,
    kv_max_seqlen,
    head_dim,
    window_size: Optional[Tuple[int, int]] = None,
):
    """
    Check whether a fused attention kernel is available for the given configuration.
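
    Example (illustrative; the result depends on the GPU and cuDNN available at runtime):
        >>> ok = is_fused_attn_kernel_available(jnp.bfloat16, jnp.bfloat16, QKVLayout.BS3HD,
                                                AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
                                                0.0, 16, 16, 2048, 2048, 64)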
    """

    def make_helper(attn_mask_type):
        return tex.FusedAttnHelper(
            q_dtype,
            kv_dtype,
            qkv_layout.value,
            attn_bias_type.value,
            attn_mask_type.value,
            dropout_probability,
            q_num_heads,
            kv_num_heads,
            q_max_seqlen,
            kv_max_seqlen,
            head_dim,
            (-1, -1) if window_size is None else window_size,
        )

    if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
        return False

    return True


def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = q_max_seqlen
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = qkv[1].shape[1]
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = qkv[1].shape[1]
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    return batch, q_max_seqlen, kv_max_seqlen


def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Reorders a tensor for load balancing the compute of causal attention."""
    seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)


def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Inverse operation of `reorder_causal_load_balancing`."""
    seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)


def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    mask: Optional[jnp.ndarray],
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    Perform non-THD (non-packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        mask (Optional[jnp.ndarray]):
            An optional mask tensor to mask out the attention scores; `True` means mask out.
            Intra-sequence padding is not supported: padded tokens may only appear at the
            right-most end of a sequence, otherwise the results will be wrong.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of the attention bias.
        attn_mask_type (AttnMaskType): Type of the attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
        context_parallel_strategy (CPStrategy):
            The context parallelism strategy to use (see `CPStrategy`).
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running
            context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
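
    Examples:
        >>> # Illustrative packed-QKV causal self-attention call; shapes are placeholders
        >>> b, s, h, d = 2, 128, 12, 64
        >>> qkv_packed = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> # mask may be None for causal masking, since seqlens are derived from the shapes
        >>> out = fused_attn((qkv_packed,), None, None, None,
                             AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
                             QKVLayout.BS3HD, 0.125, 0.0, True)  # 0.125 = 1/sqrt(d)
        >>> out.shape
        (2, 128, 12, 64)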
    """
    assert (
        get_qkv_format(qkv_layout) != QKVFormat.THD
    ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.BS3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BS2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BSHD_BSHD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

    # convert the mask to seqlens, mask doesn't support ragged offsets
    if attn_mask_type in [
        AttnMaskType.NO_MASK,
        AttnMaskType.CAUSAL_MASK,
        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
    ]:
        batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
        q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
        kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_seq_lens = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_seq_lens = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]
        else:
            # When the mask is causal, the actual seqlen is not in the last row; use max to find it
            kv_seq_lens = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output = _fused_attn(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        None,
        None,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=1,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output


def fused_attn_thd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: jnp.ndarray,
    kv_seq_offsets: jnp.ndarray,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    (Experimental) Perform THD (packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        q_seq_lens (jnp.ndarray):
            Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are
            padded with -1.
        kv_seq_lens (jnp.ndarray):
            Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused positions
            are padded with -1.
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
            Unused positions are padded with -1.
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the key and value, with shape
            [batch, max_seqlen + 1]. Unused positions are padded with -1.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of the attention bias.
        attn_mask_type (AttnMaskType): Type of the attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int):
            The maximum number of segments inside a sequence. This parameter bounds resource
            usage and needs to be static during end-to-end training. XLA compile time and
            memory consumption are proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]):
            Sliding window size.
        context_parallel_strategy (CPStrategy):
            The context parallelism strategy to use (see `CPStrategy`).
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running
            context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.

    Examples:
        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> # 3 segments in first seq, 2 segments in second seq
        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
        >>> # seq_offsets need to include the end offset of the last segments
        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
        >>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
                                 q_seq_offsets, kv_seq_offsets, None,
                                 AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
                                 QKVLayout.T3HD, 0.125, 0, True, 3)
    """
    assert (
        get_qkv_format(qkv_layout) == QKVFormat.THD
    ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_T2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_THD_THD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

    batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
    assert q_seq_lens.shape == (batch, q_max_seqlen)
    assert kv_seq_lens.shape == (batch, kv_max_seqlen)
    assert q_seq_offsets.shape == (batch, q_max_seqlen + 1)
    assert kv_seq_offsets.shape == (batch, kv_max_seqlen + 1)

    output = _fused_attn(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: Optional[jnp.ndarray],
    kv_seq_offsets: Optional[jnp.ndarray],
    seed: jnp.ndarray,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int,
    window_size: Optional[Tuple[int, int]],
    context_parallel_strategy: CPStrategy,
    context_parallel_causal_load_balanced: bool,
    context_parallel_axis: str,
):
    output, _ = _fused_attn_fwd_rule(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        window_size,
        context_parallel_strategy,
        context_parallel_causal_load_balanced,
        context_parallel_axis,
    )
    return output


def _fused_attn_fwd_rule(
    qkv,
    bias,
    q_seq_lens,
    kv_seq_lens,
    q_seq_offsets,
    kv_seq_offsets,
    seed,
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_strategy,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
):
    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        qkv_layout=qkv_layout.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    output = checkpoint_name(output, "context")
    softmax_aux = checkpoint_name(softmax_aux, "context")
    rng_state = checkpoint_name(rng_state, "context")
    return output, (
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        softmax_aux,
        rng_state,
        output,
    )


def _fused_attn_bwd_rule(
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_strategy,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    ctx,
    dz,
):
    (
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        softmax_aux,
        rng_state,
        output,
    ) = ctx
    grad_qkv, grad_bias = tex.fused_attn_bwd(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        qkv_layout=qkv_layout.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
    return grad_qkv, grad_bias, None, None, None, None, None


_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)