# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
from typing import Optional, Tuple
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp

from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_QKV_Format
from transformer_engine.transformer_engine_jax import nvte_get_qkv_format

from . import cpp_extensions as tex


class AttnBiasType(Enum):
    """
    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """

    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """
    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    CAUSAL_BOTTOM_RIGHT_MASK: A causal mask with the diagonal aligned to the bottom right
        corner of the softmax matrix.
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK: A combination of both padding and bottom-right causal masks.
    """

    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
    CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK


class QKVLayout(Enum):
    """
    BSHD Format:
        - BS3HD: q, k, v are interleaved and packed into a tensor with shape [b, s, 3, h, d].
        - BSHD_BS2HD: q has shape [b, s, h, d]; k, v are interleaved into a tensor with shape
          [b, s, 2, h, d].
        - BSHD_BSHD_BSHD: q, k, v are separate tensors, each with shape [b, s, h, d].
    THD Format: Shapes are the same as in the BSHD layout, but multiple segments may be packed
    within a sequence.
        - T3HD: q, k, v are interleaved and packed into a tensor with shape [b, s, 3, h, d].
        - THD_T2HD: q has shape [b, s, h, d]; k, v are interleaved into a tensor with shape
          [b, s, 2, h, d].
        - THD_THD_THD: q, k, v are separate tensors, each with shape [b, s, h, d].
    """

    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
    T3HD = NVTE_QKV_Layout.NVTE_T3HD
    THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
    THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD


class QKVFormat(Enum):
    """
    SBHD: q,k,v memory layout with [s, b, ..., h, d]
    BSHD: q,k,v memory layout with [b, s, ..., h, d]
    THD: q,k,v memory layout is the same as BSHD, but allows multiple segments packed in a sequence.
    """

    SBHD = NVTE_QKV_Format.NVTE_SBHD
    BSHD = NVTE_QKV_Format.NVTE_BSHD
    THD = NVTE_QKV_Format.NVTE_THD


def get_qkv_format(qkv_layout):
    """
    Get qkv_format from qkv_layout
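
    Example (illustrative):
        >>> # BS3HD is a packed layout stored in the BSHD memory format
        >>> get_qkv_format(QKVLayout.BS3HD) == QKVFormat.BSHD
        True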
    """
    return QKVFormat(nvte_get_qkv_format(qkv_layout.value))


def make_swa_mask(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    window_size: Optional[Tuple[int, int]] = None,
    attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
    dtype: jax.typing.DTypeLike = jnp.float32,
):
    """
    Generate sliding window mask. `True` or `1` means keep the element.

    For the `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask types,
    the sliding window diagonal is aligned to the bottom right corner, and for the other
    mask types, to the top left corner.

    Parameters
    ----------
    max_seqlen_q: int
        Maximum sequence length for queries.
    max_seqlen_kv: int
        Maximum sequence length for keys and values.
    window_size: Optional[Tuple[int, int]] = None
        Sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. A negative number in the window size means an infinite window.
        `None` means no sliding window.
    attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK
    dtype: jax.typing.DTypeLike, default=jnp.float32
        The mask data type.
    Returns
    ----------
    swa_mask: jax.numpy.tensor
        Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions
        that will get attention, value 0 are the masked out positions.
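
    Examples
    ----------
    An illustrative sketch (window and lengths chosen only for demonstration): with
    `window_size=(2, 0)` and a top-left aligned causal mask, query `i` attends to keys
    `j` satisfying `i - 2 <= j <= i`.

        >>> mask = make_swa_mask(4, 4, window_size=(2, 0),
        ...                      attn_mask_type=AttnMaskType.CAUSAL_MASK)
        >>> mask.shape
        (4, 4)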
    """
    swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
    if window_size is None:
        return swa_mask
    bottom_right_masks = [
        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
    ]
    left_window, right_window = window_size
    if attn_mask_type in bottom_right_masks:
        if left_window < 0:
            left_window = max_seqlen_kv
        if right_window < 0:
            right_window = max_seqlen_kv
        bottom_right_shift = max_seqlen_kv - max_seqlen_q
        swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift)
        swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift)
    else:
        if left_window < 0:
            left_window = max_seqlen_q
        if right_window < 0:
            right_window = max_seqlen_q
        swa_mask = jnp.triu(swa_mask, k=-left_window)
        swa_mask = jnp.tril(swa_mask, k=right_window)
    return swa_mask


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert string attn_mask_type to AttnMaskType
    TE-JAX currently falls back to the padding variants of the kernels for library integration.
    The overhead between the padding and non-padding variants should be small.
    However, we will lift this limitation in the near future.
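
    Example (illustrative):
        >>> canonicalize_attn_mask_type("padding_causal") is AttnMaskType.PADDING_CAUSAL_MASK
        True
        >>> canonicalize_attn_mask_type("causal_padding") is AttnMaskType.PADDING_CAUSAL_MASK
        True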
    """
    match attn_mask_type:
        case "no_mask":
            return AttnMaskType.NO_MASK
        case "padding":
            return AttnMaskType.PADDING_MASK
        case "causal":
            return AttnMaskType.CAUSAL_MASK
        case "causal_bottom_right" | "bottom_right_causal":
            return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
        case "padding_causal" | "causal_padding":
            return AttnMaskType.PADDING_CAUSAL_MASK
        case (
            "padding_causal_bottom_right"
            | "causal_padding_bottom_right"
            | "bottom_right_causal_padding"
            | "bottom_right_padding_causal"
        ):
            return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
    raise ValueError(
        f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal',"
        " 'padding_causal', 'causal_padding', 'causal_bottom_right',"
        " 'padding_causal_bottom_right'}"
    )


def is_fused_attn_kernel_available(
    q_dtype,
    kv_dtype,
    qkv_layout,
    attn_bias_type,
    attn_mask_type,
    dropout_probability,
    q_num_heads,
    kv_num_heads,
    q_max_seqlen,
    kv_max_seqlen,
    head_dim,
    window_size: Optional[Tuple[int, int]] = None,
):
    """
    Check whether a fused attention kernel is available for the given configuration.
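
    Example (an illustrative query; the dtypes, head counts, and lengths are arbitrary):
        >>> available = is_fused_attn_kernel_available(
        ...     jnp.bfloat16, jnp.bfloat16, QKVLayout.BS3HD,
        ...     AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
        ...     dropout_probability=0.0, q_num_heads=16, kv_num_heads=16,
        ...     q_max_seqlen=512, kv_max_seqlen=512, head_dim=64)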
    """

    def make_helper(attn_mask_type):
        return tex.FusedAttnHelper(
            q_dtype,
            kv_dtype,
            qkv_layout.value,
            attn_bias_type.value,
            attn_mask_type.value,
            dropout_probability,
            q_num_heads,
            kv_num_heads,
            q_max_seqlen,
            kv_max_seqlen,
            head_dim,
            (-1, -1) if window_size is None else window_size,
        )

    if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
        return False

    return True


def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = q_max_seqlen
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = qkv[1].shape[1]
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}"
            batch, q_max_seqlen, *_ = qkv[0].shape
            kv_max_seqlen = qkv[1].shape[1]
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    return batch, q_max_seqlen, kv_max_seqlen


def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Reorders a tensor for load balancing the compute of causal attention."""
    seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)


def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Inverse operation of `reorder_causal_load_balancing`."""
    seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
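
# Illustrative round trip for the two reordering helpers above (shapes are chosen only
# for demonstration):
#
#     x = jnp.zeros((2, 8, 16, 64), dtype=jnp.bfloat16)  # [b, s, h, d]
#     y = reorder_causal_load_balancing(x, cp_size=2, tensor_format=QKVFormat.BSHD)
#     x_restored = inverse_reorder_causal_load_balancing(y, cp_size=2, tensor_format=QKVFormat.BSHD)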


def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    mask: Optional[jnp.ndarray],
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    Perform non-THD (non-packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        mask (Optional[jnp.ndarray]):
            An optional mask tensor to mask out the attention scores, where `True` means mask out.
            Intra-sequence padding is not supported: padded tokens may only appear at the end
            (right-most positions) of a sequence, otherwise the results will be wrong.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
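
    Examples:
        >>> # Illustrative self-attention call; shapes and values are only for demonstration.
        >>> b, s, h, d = 2, 128, 12, 64
        >>> qkv_packed = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> out = fused_attn((qkv_packed,), None, None, None,
        ...                  AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
        ...                  QKVLayout.BS3HD, scaling_factor=0.125,
        ...                  dropout_probability=0.0, is_training=True)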
    """
    assert (
        get_qkv_format(qkv_layout) != QKVFormat.THD
    ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.BS3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BS2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.BSHD_BSHD_BSHD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

    # convert the mask to seqlens, mask doesn't support ragged offsets
    if attn_mask_type in [
        AttnMaskType.NO_MASK,
        AttnMaskType.CAUSAL_MASK,
        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
    ]:
        batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
        q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
        kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_seq_lens = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_seq_lens = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]
        else:
            # When mask is causal, the actual seqlen is not the last row, use max to find it
            kv_seq_lens = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output = _fused_attn(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        None,  # q_seq_offsets: ragged offsets apply only to the THD (packed) path
        None,  # kv_seq_offsets: ragged offsets apply only to the THD (packed) path
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=1,
        window_size=window_size,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output


def fused_attn_thd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: jnp.ndarray,
    kv_seq_offsets: jnp.ndarray,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    (Experimental) Perform THD (packed) cuDNN fused attention.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        q_seq_lens (jnp.ndarray):
            Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are
            padded with -1.
        kv_seq_lens (jnp.ndarray):
            Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused positions
            are padded with -1.
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
            Unused positions are padded with -1.
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the key and value, with shape [batch, max_seqlen + 1].
            Unused positions are padded with -1.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int):
            The maximum number of segments inside a sequence. This parameter constrains resource
            usage and needs to be static during end-to-end training. The XLA compile time and
            memory consumption are proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]):
            Sliding window size.
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.

    Examples:
        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> # 3 segments in first seq, 2 segments in second seq
        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
        >>> # seq_offsets need to include the end offset of the last segments
        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
        >>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
                                 q_seq_offsets, kv_seq_offsets, None,
                                 AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
                                 QKVLayout.T3HD, 0.125, 0, True, 3)
    """
    assert (
        get_qkv_format(qkv_layout) == QKVFormat.THD
    ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."

    # Check inputs qkv
    match qkv_layout:
        case QKVLayout.T3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_T2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        case QKVLayout.THD_THD_THD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

    batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
    assert q_seq_lens.shape == (batch, q_max_seqlen)
    assert kv_seq_lens.shape == (batch, kv_max_seqlen)
    assert q_seq_offsets.shape == (batch, q_max_seqlen + 1)
    assert kv_seq_offsets.shape == (batch, kv_max_seqlen + 1)

    output = _fused_attn(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
def _fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seq_lens: jnp.ndarray,
    kv_seq_lens: jnp.ndarray,
    q_seq_offsets: Optional[jnp.ndarray],
    kv_seq_offsets: Optional[jnp.ndarray],
    seed: jnp.ndarray,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int,
    window_size: Optional[Tuple[int, int]],
    context_parallel_causal_load_balanced: bool,
    context_parallel_axis: str,
):
    output, _ = _fused_attn_fwd_rule(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        window_size,
        context_parallel_causal_load_balanced,
        context_parallel_axis,
    )
    return output


def _fused_attn_fwd_rule(
    qkv,
    bias,
    q_seq_lens,
    kv_seq_lens,
    q_seq_offsets,
    kv_seq_offsets,
    seed,
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
):
    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        qkv_layout=qkv_layout.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    output = checkpoint_name(output, "context")
    softmax_aux = checkpoint_name(softmax_aux, "context")
    rng_state = checkpoint_name(rng_state, "context")
    return output, (
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        softmax_aux,
        rng_state,
        output,
    )


def _fused_attn_bwd_rule(
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
    scaling_factor,
    dropout_probability,
    is_training,
    max_segments_per_seq,
    window_size,
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    ctx,
    dz,
):
    (
        qkv,
        bias,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        softmax_aux,
        rng_state,
        output,
    ) = ctx
    grad_qkv, grad_bias = tex.fused_attn_bwd(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        q_seq_lens,
        kv_seq_lens,
        q_seq_offsets,
        kv_seq_offsets,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        qkv_layout=qkv_layout.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=window_size,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
    return grad_qkv, grad_bias, None, None, None, None, None


_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)