# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
from dataclasses import dataclass
from functools import partial, reduce, cache
import operator
import os
from typing import Optional, Tuple
import warnings

import jax.numpy as jnp
from jax import dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import (
    NVTE_Bias_Type,
    NVTE_Mask_Type,
    NVTE_QKV_Layout,
    NVTE_QKV_Format,
    NVTE_Fused_Attn_Backend,
    nvte_get_qkv_format,
)
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    get_cudnn_version,
)
from ..sharding import (
    all_reduce_sum_along_dp_fsdp,
    get_all_mesh_axes,
    num_of_devices,
)


__all__ = [
    "FusedAttnHelper",
    "fused_attn_fwd",
    "fused_attn_bwd",
]


@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend
    """

    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: NVTE_QKV_Layout
    attn_bias_type: NVTE_Bias_Type
    attn_mask_type: NVTE_Mask_Type
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
    q_max_seqlen: int
    kv_max_seqlen: int
    head_dim: int

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        return transformer_engine_jax.get_fused_attn_backend(
            jax_dtype_to_te_dtype(self.q_dtype),
            jax_dtype_to_te_dtype(self.kv_dtype),
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_probability,
            self.q_num_heads,
            self.kv_num_heads,
            self.q_max_seqlen,
            self.kv_max_seqlen,
            self.head_dim,
        )

    @staticmethod
    @cache
    def is_non_deterministic_allowed():
        """Check if non-deterministic kernels are allowed"""
        return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

    @staticmethod
    def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
        """Parse qkv aval"""
        match qkv_layout:
            case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
                *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
                kv_batch_shape = q_batch_shape
                kv_max_seqlen = q_max_seqlen
                num_gqa_groups = attn_heads
                kv_head_dim = q_head_dim
                assert nqkv == 3
            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
                *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
                *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
                assert nkv == 2
            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
                *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
                *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
                assert k_aval.shape == v_aval.shape
            case _:
                raise ValueError(f"Unexpected {qkv_layout=}")
        assert q_batch_shape == kv_batch_shape
        assert q_head_dim == kv_head_dim
        assert q_aval.dtype == k_aval.dtype == v_aval.dtype

        return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)

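# Usage sketch (illustrative, not from the original source): probe for a fused
# kernel before dispatching, and fall back to an unfused path otherwise.
#     helper = FusedAttnHelper(
#         jnp.bfloat16, jnp.bfloat16, NVTE_QKV_Layout.NVTE_BS3HD,
#         NVTE_Bias_Type.NVTE_NO_BIAS, NVTE_Mask_Type.NVTE_PADDING_MASK,
#         0.0, 16, 16, 512, 512, 64)
#     if helper.is_fused_attn_kernel_available():
#         ...  # dispatch to the fused path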

@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.
    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    However, JAX doesn't enable 64 bits by default,
    so we have to emulate seed as two 32 bits array.
    The offset calculation is maintained in the backend.
    """

    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64
    seed_size: int = 2
    # (seed, offset) with internal dtype int64
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.
        """
        # JAX can't bind None, so create a dummy tensor in its place
        if seed is None:
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            seed = jnp.zeros(2, dtype=self.rng_state_dtype)
            seed = jnp.repeat(seed, num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"cast to dtype {self.rng_state_dtype}. "
                "Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning."
            )
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size

        return seed

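# Note (illustrative, not from the original source): a JAX threefry PRNG key is
# already a (2,) uint32 array, matching _FusedAttnRNGStateChecker.rng_state_dtype:
#     seed = _FusedAttnRNGStateChecker().check_seed(
#         jax.random.PRNGKey(0), dropout_probability=0.1, is_training=True)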

def generate_cu_seqlen(actual_seqlen):
    """
    Generate the cumulative sequence lengths (cu_seqlen) of a batch
    """
    cu_seqlen = jnp.cumsum(actual_seqlen, axis=-1)
    cu_seqlen = jnp.where(actual_seqlen < 0, -1, cu_seqlen)
    cu_seqlen = jnp.insert(cu_seqlen, 0, values=0, axis=-1)
    return cu_seqlen


class FusedAttnFwdPrimitive(BasePrimitive):
    """
    Fused Attention Forward Primitive
    """

    name = "te_fused_attn_forward"
    multiple_results = True
    impl_static_args = (9, 10, 11, 12, 13, 14, 15)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        seed_aval,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
    ):
        """
        Fused attention fwd abstract
        """
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
        )

        output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
        out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)

        # backend determines the softmax buffer shape/dtype
        backend = FusedAttnHelper(
            q_dtype,
            k_dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            dropout_probability,
            attn_heads,
            num_gqa_groups,
            q_max_seqlen,
            kv_max_seqlen,
            head_dim,
        ).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
            softmax_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, max_segments_per_seq)
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f"Unsupported {backend=}")
        softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)

        # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
        # 32-bit unsigned int to get the buffer size we need in the C++ kernel
        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
        # prepare for the active fused-attn backend
        input_batch = reduce(operator.mul, batch_shape)
        wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            scaling_factor,
            dropout_probability,
            attn_bias_type,
            attn_mask_type,
            qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            is_training,
            max_segments_per_seq,
        )
        wkspace_aval = q_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract
        """
        out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract(
            *args, **kwargs
        )
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        seed,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
    ):
        """
        Fused attention fwd lowering rules
        """
        operands = [
            q,
            k,
            v,
            bias,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            seed,
        ]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
        )

        input_batch = reduce(operator.mul, batch_shape)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        wkspace_aval = ctx.avals_out[-1]

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            max_segments_per_seq,
            wkspace_aval.size,
            scaling_factor,
            dropout_probability,
            attn_bias_type,
            attn_mask_type,
            qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            is_training,
            not FusedAttnHelper.is_non_deterministic_allowed(),
        )

        out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
    ):
        assert FusedAttnFwdPrimitive.inner_primitive is not None

        if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:

            def _fix_len_take(x, condition, fill_value=-1):
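                # Compact the elements of x that satisfy `condition` to the
                # front in row-major order, padding the tail with `fill_value`.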
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
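                # Shift each row's offsets into the flattened (batch * max_seqlen)
                # index space; negative (unused) entries are left as-is.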
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            match qkv_layout:
                case NVTE_QKV_Layout.NVTE_T3HD:
                    kv_max_seqlen = q_max_seqlen = q.shape[-4]
                    kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
                case NVTE_QKV_Layout.NVTE_THD_T2HD:
                    q_max_seqlen = q.shape[-3]
                    q_batch = reduce(operator.mul, q.shape[:-3])
                    kv_max_seqlen = k.shape[-4]
                    kv_batch = reduce(operator.mul, k.shape[:-4])
                case NVTE_QKV_Layout.NVTE_THD_THD_THD:
                    q_max_seqlen = q.shape[-3]
                    q_batch = reduce(operator.mul, q.shape[:-3])
                    kv_max_seqlen = k.shape[-3]
                    kv_batch = reduce(operator.mul, k.shape[:-3])

            # Gather the valid sequence lengths (entries greater than 0) to the front.
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1
            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
            # Gather the valid offsets (entries greater than or equal to 0) to the front.
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0)
            k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0)

            # Set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets)
            k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets)

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training,
            max_segments_per_seq=max_segments_per_seq,
        )
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
    ):
        check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, *_, seed_bdim = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
        return (
            FusedAttnFwdPrimitive.outer_primitive.bind(
                *batched_args,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training,
                max_segments_per_seq=max_segments_per_seq,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        mesh,
        arg_infos,
        result_infos,
    ):
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, max_segments_per_seq, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        match qkv_layout:
            case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
                softmax_aux_sharding = NamedSharding(
                    mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
                )
            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
                softmax_aux_sharding = NamedSharding(
                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4])
                )
            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
                softmax_aux_sharding = NamedSharding(
                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])
                )
            case _:
                raise ValueError(f"Unsupported {qkv_layout=}")
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        mesh,
        arg_infos,
        result_infos,
    ):
        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(
            FusedAttnFwdPrimitive.impl,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training,
            max_segments_per_seq=max_segments_per_seq,
        )
        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnFwdPrimitive)


class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive
    """

    name = "te_fused_attn_backward"
    multiple_results = True
    impl_static_args = (12, 13, 14, 15, 16, 17, 18)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_aux_aval,
        rng_state_aval,
        output_aval,
        doutput_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
    ):
        """
        Fused attention bwd abstract
        """
        del softmax_aux_aval, rng_state_aval, output_aval

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
        )

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        deterministic = not FusedAttnHelper.is_non_deterministic_allowed()

        input_batch = reduce(operator.mul, batch_shape)
        wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            scaling_factor,
            dropout_probability,
            attn_bias_type,
            attn_mask_type,
            qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            is_training,
            deterministic,
            max_segments_per_seq,
        )

        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(
            shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
        )

        return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        return dq_aval, dk_aval, dv_aval, dbias_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
    ):
        """
        Fused attention bwd lowering rules
        """
        operands = [
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
        ]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
        )

        input_batch = reduce(operator.mul, batch_shape)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        wkspace_aval = ctx.avals_out[-1]

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            max_segments_per_seq,
            wkspace_aval.size,
            scaling_factor,
            dropout_probability,
            attn_bias_type,
            attn_mask_type,
            qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            is_training,
            not FusedAttnHelper.is_non_deterministic_allowed(),
        )

        out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
    ):
        assert FusedAttnBwdPrimitive.inner_primitive is not None

        if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:

            def _fix_len_take(x, condition, fill_value=-1):
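                # Compact the elements of x that satisfy `condition` to the
                # front in row-major order, padding the tail with `fill_value`.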
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                # TODO(rewang): try indices_are_sorted
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
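                # Shift each row's offsets into the flattened (batch * max_seqlen)
                # index space; negative (unused) entries are left as-is.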
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            match qkv_layout:
                case NVTE_QKV_Layout.NVTE_T3HD:
                    kv_max_seqlen = q_max_seqlen = q.shape[-4]
                    kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
                case NVTE_QKV_Layout.NVTE_THD_T2HD:
                    q_max_seqlen = q.shape[-3]
                    q_batch = reduce(operator.mul, q.shape[:-3])
                    kv_max_seqlen = k.shape[-4]
                    kv_batch = reduce(operator.mul, k.shape[:-4])
                case NVTE_QKV_Layout.NVTE_THD_THD_THD:
                    q_max_seqlen = q.shape[-3]
                    q_batch = reduce(operator.mul, q.shape[:-3])
                    kv_max_seqlen = k.shape[-3]
                    kv_batch = reduce(operator.mul, k.shape[:-3])

            # Gather the valid sequence lengths (entries greater than 0) to the front.
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1
            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
            # Gather the valid offsets (entries greater than or equal to 0) to the front.
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0)
            k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0)

            # Set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets)
            k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets)

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training,
            max_segments_per_seq=max_segments_per_seq,
        )
        return dq, dk, dv, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
    ):
        check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, k_bdim, v_bdim, *_ = batch_dims

        out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
        return (
            FusedAttnBwdPrimitive.outer_primitive.bind(
                *batched_args,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training,
                max_segments_per_seq=max_segments_per_seq,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        mesh,
        arg_infos,
        result_infos,
    ):
        del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, max_segments_per_seq
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

    @staticmethod
    def partition(
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        max_segments_per_seq,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        def sharded_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
        ):
            local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training,
                max_segments_per_seq=max_segments_per_seq,
            )
            global_dbias = local_dbias
            if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            return local_dq, local_dk, local_dv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(FusedAttnBwdPrimitive)


def fused_attn_fwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    q_seqlen: jnp.ndarray,
    kv_seqlen: jnp.ndarray,
    q_seq_offsets: Optional[jnp.ndarray],
    kv_seq_offsets: Optional[jnp.ndarray],
    seed: Optional[jnp.ndarray],
    attn_bias_type: NVTE_Bias_Type,
    attn_mask_type: NVTE_Mask_Type,
    qkv_layout: NVTE_QKV_Layout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int,
) -> jnp.ndarray:
    """
    Perform the forward pass with cuDNN fused attention implementations.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (Optional[jnp.ndarray]):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (Optional[jnp.ndarray]):
            The offsets in the sequence dim for the key and value, with shape [batch + 1,].
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (NVTE_Bias_Type): Type of attention bias.
        attn_mask_type (NVTE_Mask_Type): Type of attention mask.
        qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int):
            The maximum number of segments packed into each sequence; relevant for the
            THD (ragged) layouts.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention, along with the
        auxiliary softmax statistics and RNG state needed by the backward pass.
    """
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)

    assert (q_seq_offsets is None) == (
        kv_seq_offsets is None
    ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values."
    is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD

    # Custom calls do not support None, so create dummy tensors for the optional inputs
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
    match qkv_layout:
        case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
            qkv_for_primitive = [*qkv, _not_used, _not_used]
        case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
            qkv_for_primitive = [*qkv, _not_used]
        case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
            qkv_for_primitive = qkv

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv[0].dtype)

    return FusedAttnFwdPrimitive.outer_primitive.bind(
        *qkv_for_primitive,
        bias,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets if is_ragged else _not_used,
        kv_seq_offsets if is_ragged else _not_used,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
    )
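
# Reference semantics (illustrative sketch, not part of this module's API): an
# unfused JAX equivalent of the docstring formula, ignoring masking and dropout.
#     def unfused_attn(q, k, v, bias, scale):
#         logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale  # BMM1 + scale
#         if bias is not None:
#             logits = logits + bias                            # bias
#         probs = jax.nn.softmax(logits, axis=-1)               # softmax
#         return jnp.einsum("bhqk,bkhd->bqhd", probs, v)        # BMM2
#
# Usage sketch (illustrative; `qkv_packed`, `seqlens`, and `head_dim` are assumed
# to exist): self-attention on a packed BS3HD tensor without dropout.
#     out, softmax_aux, rng_state = fused_attn_fwd(
#         (qkv_packed,), bias=None, q_seqlen=seqlens, kv_seqlen=seqlens,
#         q_seq_offsets=None, kv_seq_offsets=None, seed=None,
#         attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
#         attn_mask_type=NVTE_Mask_Type.NVTE_PADDING_MASK,
#         qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
#         scaling_factor=head_dim**-0.5, dropout_probability=0.0,
#         is_training=True, max_segments_per_seq=1)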


def fused_attn_bwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
    q_seqlen: jnp.ndarray,
    kv_seqlen: jnp.ndarray,
    q_seq_offsets: Optional[jnp.ndarray],
    kv_seq_offsets: Optional[jnp.ndarray],
    attn_bias_type: NVTE_Bias_Type,
    attn_mask_type: NVTE_Mask_Type,
    qkv_layout: NVTE_QKV_Layout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int,
):
    """
    Perform the backward pass of the cuDNN fused attention implementations.

    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
        used in the forward pass. It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
        rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
        output (jnp.ndarray): The output tensor from the forward pass.
        doutput (jnp.ndarray): The gradient with respect to the output.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (Optional[jnp.ndarray]):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (Optional[jnp.ndarray]):
            The offsets in the sequence dim for the key and value, with shape [batch + 1,].
        attn_bias_type (NVTE_Bias_Type): Type of attention bias.
        attn_mask_type (NVTE_Mask_Type): Type of attention mask.
        qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int):
            The maximum number of segments packed into each sequence; relevant for the
            THD (ragged) layouts.

    Returns:
        Tuple[jnp.ndarray, ...], jnp.ndarray:
        - The first tuple contains the gradients with respect to the input `qkv` tensors in the
          same format as the input `qkv`.
        - The second value is the gradient with respect to `bias` (a dummy zero-size
          gradient when `attn_bias_type` is `NVTE_NO_BIAS` and `bias` is `None`).
    """

    assert (q_seq_offsets is None) == (
        kv_seq_offsets is None
    ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values."
    is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD

    # Custom calls do not support None, so create dummy tensors for the optional inputs
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)

    match qkv_layout:
        case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
            qkv_for_primitive = [*qkv, _not_used, _not_used]
        case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
            assert (
                len(qkv) == 2
            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
            qkv_for_primitive = [*qkv, _not_used]
        case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
            assert (
                len(qkv) == 3
            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
            qkv_for_primitive = qkv

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv[0].dtype)

    *qkv_grads, bias_grad = FusedAttnBwdPrimitive.outer_primitive.bind(
        *qkv_for_primitive,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets if is_ragged else _not_used,
        kv_seq_offsets if is_ragged else _not_used,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
    )
    return tuple(qkv_grads[: len(qkv)]), bias_grad
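
# Usage sketch (illustrative; reuses the names from the fused_attn_fwd sketch
# above, with `dout` the incoming gradient w.r.t. `out`): the backward pass
# consumes the forward auxiliaries and returns grads in the same format as `qkv`.
#     (dqkv_packed,), dbias = fused_attn_bwd(
#         (qkv_packed,), None, softmax_aux, rng_state, out, dout,
#         q_seqlen=seqlens, kv_seqlen=seqlens,
#         q_seq_offsets=None, kv_seq_offsets=None,
#         attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
#         attn_mask_type=NVTE_Mask_Type.NVTE_PADDING_MASK,
#         qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
#         scaling_factor=head_dim**-0.5, dropout_probability=0.0,
#         is_training=True, max_segments_per_seq=1)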