attention.py 33.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
from dataclasses import dataclass
from functools import partial, reduce
import operator
import warnings

import jax.numpy as jnp
from jax import dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import (
    NVTE_Bias_Type,
    NVTE_Mask_Type,
    NVTE_QKV_Layout,
21
    NVTE_Fused_Attn_Backend,
22
23
24
25
26
27
28
)
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
29
    get_padded_spec,
30
31
32
33
34
35
36
37
)
from ..sharding import (
    all_reduce_sum_along_dp_fsdp,
    get_all_mesh_axes,
    num_of_devices,
)


38
39
40
41
42
43
44
45
46
# Names exported by a star-import of this module
__all__ = [
    "FusedAttnHelper",
    "fused_attn_fwd_qkvpacked",
    "fused_attn_bwd_qkvpacked",
    "fused_attn_fwd_kvpacked",
    "fused_attn_bwd_kvpacked",
    "fused_attn_fwd",
    "fused_attn_bwd",
]
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73


@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend
    """

    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: NVTE_QKV_Layout
    attn_bias_type: NVTE_Bias_Type
    attn_mask_type: NVTE_Mask_Type
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
    q_max_seqlen: int
    kv_max_seqlen: int
    head_dim: int

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        return transformer_engine_jax.get_fused_attn_backend(
74
75
76
77
78
79
80
81
82
83
84
85
            jax_dtype_to_te_dtype(self.q_dtype),
            jax_dtype_to_te_dtype(self.kv_dtype),
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_probability,
            self.q_num_heads,
            self.kv_num_heads,
            self.q_max_seqlen,
            self.kv_max_seqlen,
            self.head_dim,
        )
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123

    @staticmethod
    def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
        """Parse qkv aval"""
        match qkv_layout:
            case NVTE_QKV_Layout.NVTE_BS3HD:
                *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
                kv_batch_shape = q_batch_shape
                kv_max_seqlen = q_max_seqlen
                num_gqa_groups = attn_heads
                kv_head_dim = q_head_dim
                assert nqkv == 3
            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
                *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
                *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
                assert nkv == 2
            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
                *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
                *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
                assert k_aval.shape == v_aval.shape
            case _:
                raise ValueError(f"Unexpected {qkv_layout=}")
        assert q_batch_shape == kv_batch_shape
        assert q_head_dim == kv_head_dim
        assert q_aval.dtype == k_aval.dtype == v_aval.dtype

        return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.
    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    However, JAX doesn't enable 64 bits by default,
    so we have to emulate seed as two 32 bits array.
    The offset calculation is maintained in the backend.
    """

    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64
    seed_size: int = 2
    # (seed, offset) with internal dtype int64
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.

        When `seed` is None a zero-filled placeholder is created (only permitted
        while dropout is inactive). A seed whose dtype differs from
        `rng_state_dtype` is cast, with a warning. The returned array has dtype
        `rng_state_dtype` and at least `seed_size` elements.
        """
        # Jax can't bind None, create a dummy tensor for None
        if seed is None:
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            seed = jnp.zeros(self.seed_size, dtype=self.rng_state_dtype)
            # One placeholder seed per device
            seed = jnp.repeat(seed, num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                "Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning."
            )
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size

        return seed


def generate_cu_seqlen(actual_seqlen):
    """
    Build the cumulative sequence lengths of a batch:
    a leading 0 followed by the running sum of `actual_seqlen`
    """
    running_total = jnp.cumsum(actual_seqlen)
    # Prepend the zero start offset ahead of the running totals
    return jnp.hstack((0, running_total))


class FusedAttnFwdPrimitive(BasePrimitive):
    """
    Fused Attention Forward Primitive
    """
170

171
172
173
174
175
176
177
    name = "te_fused_attn_forward"
    multiple_results = True
    impl_static_args = (7, 8, 9, 10, 11, 12)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        seed_aval,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
    ):
194
195
196
197
198
199
200
201
202
203
        """
        Fused attention fwd abstract
        """
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

204
        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
205
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
206
        )
207
208
209
210
211

        output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
        out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)

        # backend determines the softmax buffer shape/dtype
212
213
214
215
216
217
218
219
220
221
222
223
224
        backend = FusedAttnHelper(
            q_dtype,
            k_dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            dropout_probability,
            attn_heads,
            num_gqa_groups,
            q_max_seqlen,
            kv_max_seqlen,
            head_dim,
        ).get_fused_attn_backend()
225
226
227
228
229
230
231
232

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
            softmax_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
233
            raise ValueError(f"Unsupported {backend=}")
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
        softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)

        # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
        # 32-bit unsigned int to get the buffer size we need in the C++ kernel
        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
        # prepare for the active fused-attn backend
        input_batch = reduce(operator.mul, batch_shape)
        wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            scaling_factor,
            dropout_probability,
            attn_bias_type,
            attn_mask_type,
            qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            is_training,
        )
        wkspace_aval = q_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )
273
274
275
276
277
278
279
280

        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract
        """
281
282
283
        out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract(
            *args, **kwargs
        )
284
285
286
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        q_cu_seqlen,
        kv_cu_seqlen,
        seed,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
    ):
304
305
306
307
308
309
310
311
312
313
314
315
316
        """
        Fused attention fwd lowering rules
        """
        operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

317
        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
318
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
319
        )
320
321
322
323
324
325
326
327
328
329
330
331

        input_batch = reduce(operator.mul, batch_shape)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        wkspace_aval = ctx.avals_out[-1]

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            wkspace_aval.size,
            scaling_factor,
            dropout_probability,
            attn_bias_type,
            attn_mask_type,
            qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            is_training,
        )
350
351
352
353
354
355

        out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
    def impl(
        q,
        k,
        v,
        bias,
        q_seqlen,
        kv_seqlen,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
    ):
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
        assert FusedAttnFwdPrimitive.inner_primitive is not None

        q_cu_seqlen = generate_cu_seqlen(q_seqlen)
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)

        output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            q_cu_seqlen,
            kv_cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
389
390
            is_training=is_training,
        )
391
392
393
        return output, softmax_aux, rng_state

    @staticmethod
394
395
396
397
398
399
400
401
402
403
404
    def batcher(
        batched_args,
        batch_dims,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
    ):
405
406
407
408
409
        check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, *_, seed_bdim = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
410
411
412
413
414
415
416
417
418
419
420
421
        return (
            FusedAttnFwdPrimitive.outer_primitive.bind(
                *batched_args,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training,
            ),
            out_bdims,
        )
422
423

    @staticmethod
424
425
426
427
428
429
430
431
432
433
434
    def infer_sharding_from_operands(
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        mesh,
        arg_infos,
        result_infos,
    ):
435
436
437
438
439
440
441
442
443
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        match qkv_layout:
            case NVTE_QKV_Layout.NVTE_BS3HD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
                softmax_aux_sharding = NamedSharding(
444
445
                    mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
                )
446
447
448
449
450
            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
                softmax_aux_sharding = NamedSharding(
451
452
                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4])
                )
453
454
455
456
457
            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
                softmax_aux_sharding = NamedSharding(
458
459
                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])
                )
460
461
462
463
464
465
            case _:
                raise ValueError(f"Unsupported {qkv_layout=}")
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
466
467
468
469
470
471
472
473
474
475
476
    def partition(
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        mesh,
        arg_infos,
        result_infos,
    ):
477
478
        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
479
480
481
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
482
483
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
484
485
486
487
488
489
490
491
492
        impl = partial(
            FusedAttnFwdPrimitive.impl,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training,
        )
493
494
495
496
497
498
499
500
501
502
        return mesh, impl, out_shardings, arg_shardings


# NOTE(review): presumably sets FusedAttnFwdPrimitive.inner_primitive/outer_primitive -- confirm in .base
register_primitive(FusedAttnFwdPrimitive)


class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive
    """

    name = "te_fused_attn_backward"
    multiple_results = True
    # Positions of the static (non-array) arguments in the positional signature of `impl`
    impl_static_args = (10, 11, 12, 13, 14, 15)
    # NOTE(review): presumably filled in by register_primitive() -- confirm in .base
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_aux_aval,
        rng_state_aval,
        output_aval,
        doutput_aval,
        q_cu_seqlen_aval,
        kv_cu_seqlen_aval,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
    ):
        """
        Fused attention bwd abstract

        Returns the abstract values of ``(dq, dk, dv, dbias, workspace)``. Each
        gradient has the shape and dtype of its operand; the workspace
        shape/dtype is queried from
        ``transformer_engine_jax.get_fused_attn_bwd_workspace_sizes``.
        """
        del softmax_aux_aval, rng_state_aval, output_aval

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
        assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
        )

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            # Initial value 1 keeps the product defined when there are no leading batch dims
            bias_batch = reduce(operator.mul, bias_batch_shape, 1)

        input_batch = reduce(operator.mul, batch_shape, 1)
        wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            scaling_factor,
            dropout_probability,
            attn_bias_type,
            attn_mask_type,
            qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            is_training,
        )

        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(
            shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
        )

        return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract

        Same as `abstract` but without the trailing workspace aval.
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        return dq_aval, dk_aval, dv_aval, dbias_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_cu_seqlen,
        kv_cu_seqlen,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
    ):
        """
        Fused attention bwd lowering rules

        Packs the problem description into an opaque descriptor and emits the
        custom call registered under `name`.
        """
        operands = [
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
        ]
        # Materialised as a list so the shapes can be traversed more than once
        operand_shapes = [operand.type.shape for operand in operands]
        # Loop variable deliberately not called `output`, which is an operand here
        out_types = [
            ir.RankedTensorType.get(aval.shape, mlir.dtype_to_ir_type(aval.dtype))
            for aval in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
        )

        # Initial value 1 keeps the product defined when there are no leading batch dims
        input_batch = reduce(operator.mul, batch_shape, 1)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape, 1)

        # The workspace is the last output declared by `abstract`
        wkspace_aval = ctx.avals_out[-1]

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            wkspace_aval.size,
            scaling_factor,
            dropout_probability,
            attn_bias_type,
            attn_mask_type,
            qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            is_training,
        )

        out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
    ):
        """
        Fused attention bwd implementation

        Converts the sequence lengths to cumulative form, binds the inner
        primitive and returns ``(dq, dk, dv, dbias)``; the workspace result is
        dropped.
        """
        assert FusedAttnBwdPrimitive.inner_primitive is not None

        q_cu_seqlen = generate_cu_seqlen(q_seqlen)
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)

        dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training,
        )
        return dq, dk, dv, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
    ):
        """
        Fused attention bwd batching rule

        Binds the outer primitive on the batched operands. dq/dk/dv reuse the
        batch dims of q/k/v; dbias reuses the batch dim of q.
        """
        check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, k_bdim, v_bdim, *_ = batch_dims

        out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
        return (
            FusedAttnBwdPrimitive.outer_primitive.bind(
                *batched_args,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Each gradient is sharded like its operand: dq like q, dk like k,
        dv like v and dbias like bias.
        """
        del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

    @staticmethod
    def partition(
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
        scaling_factor,
        dropout_probability,
        is_training,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Partitioning rule: returns ``(mesh, sharded_impl, out_shardings, arg_shardings)``.

        Operands keep their incoming shardings and each gradient is sharded like
        its operand. When a bias is in use, the per-shard dbias is passed through
        ``all_reduce_sum_along_dp_fsdp`` before being returned.
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        def sharded_impl(
            q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen
        ):
            # The last two operands are forwarded to `impl`, which builds the
            # cumulative sequence lengths itself
            local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlen,
                kv_seqlen,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training,
            )
            global_dbias = local_dbias
            # Equality (not identity) test, matching every other NVTE_NO_BIAS check in this class
            if attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dq, local_dk, local_dv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


# NOTE(review): presumably sets FusedAttnBwdPrimitive.inner_primitive/outer_primitive -- confirm in .base
register_primitive(FusedAttnBwdPrimitive)


828
829
830
831
832
833
834
835
836
837
838
def fused_attn_fwd_qkvpacked(
    qkv: jnp.ndarray,
    bias: jnp.ndarray,
    seqlen: jnp.ndarray,
    seed: jnp.ndarray,
    attn_bias_type: NVTE_Bias_Type,
    attn_mask_type: NVTE_Mask_Type,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Wrapper for TE self fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    The packed `qkv` tensor is bound as the q operand with the ``NVTE_BS3HD``
    layout; the k and v operands are empty placeholders and `seqlen` is used for
    both the q and kv sequence lengths. `bias` must be None when `attn_bias_type`
    is ``NVTE_NO_BIAS``. Returns whatever the outer forward primitive binds to.
    """
    checker = _FusedAttnRNGStateChecker()
    seed = checker.check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None, "bias must be None when attn_bias_type is NVTE_NO_BIAS"
        # JAX cannot bind None, so an empty tensor stands in for the missing bias
        bias = jnp.zeros(0, dtype=qkv.dtype)

    # Empty stand-ins for the k and v operands, which the packed layout does not use
    _not_used = jnp.zeros(0, dtype=qkv.dtype)
    return FusedAttnFwdPrimitive.outer_primitive.bind(
        qkv,
        _not_used,
        _not_used,
        bias,
        seqlen,
        seqlen,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )


def fused_attn_bwd_qkvpacked(
    qkv: jnp.ndarray,
    bias: jnp.ndarray,
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
    seqlen: jnp.ndarray,
    attn_bias_type: NVTE_Bias_Type,
    attn_mask_type: NVTE_Mask_Type,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Wrapper for TE self fused attention bwd
    Return the gradients of self fused attention with packed qkv input

    Args:
        qkv: packed query/key/value tensor; bound with the NVTE_BS3HD layout.
        bias: attention bias; must be None when `attn_bias_type` is
            NVTE_NO_BIAS, in which case a zero-size placeholder is bound instead.
        softmax_aux: forwarded unchanged to the primitive.
        rng_state: forwarded unchanged to the primitive.
        output: forwarded unchanged to the primitive.
        doutput: forwarded unchanged to the primitive.
        seqlen: sequence lengths; bound as both the query and the key/value
            sequence-length operand.
        attn_bias_type: bias type forwarded to the primitive.
        attn_mask_type: mask type forwarded to the primitive.
        scaling_factor: forwarded unchanged to the primitive.
        dropout_probability: forwarded unchanged to the primitive.
        is_training: forwarded unchanged to the primitive.

    Returns:
        A `(dqkv, dbias)` tuple: the first and the last output of the
        backward primitive. Any outputs in between are discarded.
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None, "bias must be None when attn_bias_type is NVTE_NO_BIAS"
        # The primitive always takes a bias operand; bind an empty one.
        bias = jnp.zeros(0, dtype=qkv.dtype)

    # The packed tensor occupies the first (q) operand slot, so the k and v
    # slots are filled with zero-size placeholders.
    dummy_input = jnp.zeros(0, dtype=qkv.dtype)
    dqkv, *_, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
        qkv,
        dummy_input,
        dummy_input,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        seqlen,  # query sequence lengths
        seqlen,  # key/value sequence lengths (same tensor for packed qkv)
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )
    return dqkv, dbias


911
912
913
914
915
916
917
918
919
920
921
922
923
def fused_attn_fwd_kvpacked(
    q: jnp.ndarray,
    kv: jnp.ndarray,
    bias: jnp.ndarray,
    q_seqlen: jnp.ndarray,
    kv_seqlen: jnp.ndarray,
    seed: jnp.ndarray,
    attn_bias_type: NVTE_Bias_Type,
    attn_mask_type: NVTE_Mask_Type,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Wrapper for TE fused attention fwd with kvpacked inputs
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    Args:
        q: query tensor.
        kv: packed key/value tensor; bound with the NVTE_BSHD_BS2HD layout.
        bias: attention bias; must be None when `attn_bias_type` is
            NVTE_NO_BIAS, in which case a zero-size placeholder is bound instead.
        q_seqlen: query sequence lengths, forwarded to the primitive.
        kv_seqlen: key/value sequence lengths, forwarded to the primitive.
        seed: RNG seed; passed through `_FusedAttnRNGStateChecker.check_seed`
            before being bound.
        attn_bias_type: bias type forwarded to the primitive.
        attn_mask_type: mask type forwarded to the primitive.
        scaling_factor: forwarded unchanged to the primitive.
        dropout_probability: forwarded to the primitive and to the seed check.
        is_training: forwarded to the primitive and to the seed check.

    Returns:
        The result of `FusedAttnFwdPrimitive.outer_primitive.bind`, unmodified.
    """
    checker = _FusedAttnRNGStateChecker()
    seed = checker.check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None, "bias must be None when attn_bias_type is NVTE_NO_BIAS"
        # The primitive always takes a bias operand; bind an empty one.
        bias = jnp.zeros(0, dtype=q.dtype)

    # The packed kv tensor occupies the second (k) operand slot, so the v slot
    # is filled with a zero-size placeholder.
    dummy_input = jnp.zeros(0, dtype=q.dtype)
    return FusedAttnFwdPrimitive.outer_primitive.bind(
        q,
        kv,
        dummy_input,
        bias,
        q_seqlen,
        kv_seqlen,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )


def fused_attn_bwd_kvpacked(
    q: jnp.ndarray,
    kv: jnp.ndarray,
    bias: jnp.ndarray,
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
    q_seqlen: jnp.ndarray,
    kv_seqlen: jnp.ndarray,
    attn_bias_type: NVTE_Bias_Type,
    attn_mask_type: NVTE_Mask_Type,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Wrapper for TE fused attention bwd with kvpacked inputs
    Return the gradients of fused attention with packed kv input

    Args:
        q: query tensor.
        kv: packed key/value tensor; bound with the NVTE_BSHD_BS2HD layout.
        bias: attention bias; must be None when `attn_bias_type` is
            NVTE_NO_BIAS, in which case a zero-size placeholder is bound instead.
        softmax_aux: forwarded unchanged to the primitive.
        rng_state: forwarded unchanged to the primitive.
        output: forwarded unchanged to the primitive.
        doutput: forwarded unchanged to the primitive.
        q_seqlen: query sequence lengths, forwarded to the primitive.
        kv_seqlen: key/value sequence lengths, forwarded to the primitive.
        attn_bias_type: bias type forwarded to the primitive.
        attn_mask_type: mask type forwarded to the primitive.
        scaling_factor: forwarded unchanged to the primitive.
        dropout_probability: forwarded unchanged to the primitive.
        is_training: forwarded unchanged to the primitive.

    Returns:
        A `(dq, dkv, dbias)` tuple: the first, second and fourth outputs of
        the backward primitive. The third output is discarded.
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None, "bias must be None when attn_bias_type is NVTE_NO_BIAS"
        # The primitive always takes a bias operand; bind an empty one.
        bias = jnp.zeros(0, dtype=q.dtype)

    # The packed kv tensor occupies the second (k) operand slot, so the v slot
    # is filled with a zero-size placeholder.
    dummy_input = jnp.zeros(0, dtype=q.dtype)
    dq, dkv, _, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
        q,
        kv,
        dummy_input,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )
    return dq, dkv, dbias


997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
def fused_attn_fwd(
    q: jnp.ndarray,
    k: jnp.ndarray,
    v: jnp.ndarray,
    bias: jnp.ndarray,
    q_seqlen: jnp.ndarray,
    kv_seqlen: jnp.ndarray,
    seed: jnp.ndarray,
    attn_bias_type: NVTE_Bias_Type,
    attn_mask_type: NVTE_Mask_Type,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Wrapper for TE fused attention fwd, where query, key, value are separated tensors
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    Args:
        q: query tensor.
        k: key tensor.
        v: value tensor. The three tensors are bound with the
            NVTE_BSHD_BSHD_BSHD layout.
        bias: attention bias; must be None when `attn_bias_type` is
            NVTE_NO_BIAS, in which case a zero-size placeholder is bound instead.
        q_seqlen: query sequence lengths, forwarded to the primitive.
        kv_seqlen: key/value sequence lengths, forwarded to the primitive.
        seed: RNG seed; passed through `_FusedAttnRNGStateChecker.check_seed`
            before being bound.
        attn_bias_type: bias type forwarded to the primitive.
        attn_mask_type: mask type forwarded to the primitive.
        scaling_factor: forwarded unchanged to the primitive.
        dropout_probability: forwarded to the primitive and to the seed check.
        is_training: forwarded to the primitive and to the seed check.

    Returns:
        The result of `FusedAttnFwdPrimitive.outer_primitive.bind`, unmodified.
    """
    checker = _FusedAttnRNGStateChecker()
    seed = checker.check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None, "bias must be None when attn_bias_type is NVTE_NO_BIAS"
        # The primitive always takes a bias operand; bind an empty one.
        bias = jnp.zeros(0, dtype=q.dtype)

    return FusedAttnFwdPrimitive.outer_primitive.bind(
        q,
        k,
        v,
        bias,
        q_seqlen,
        kv_seqlen,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )


def fused_attn_bwd(
    q: jnp.ndarray,
    k: jnp.ndarray,
    v: jnp.ndarray,
    bias: jnp.ndarray,
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
    q_seqlen: jnp.ndarray,
    kv_seqlen: jnp.ndarray,
    attn_bias_type: NVTE_Bias_Type,
    attn_mask_type: NVTE_Mask_Type,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Wrapper for TE fused attention bwd
    Return the gradients of fused attention with separated query, key, value tensors

    Args:
        q: query tensor.
        k: key tensor.
        v: value tensor. The three tensors are bound with the
            NVTE_BSHD_BSHD_BSHD layout.
        bias: attention bias; must be None when `attn_bias_type` is
            NVTE_NO_BIAS, in which case a zero-size placeholder is bound instead.
        softmax_aux: forwarded unchanged to the primitive.
        rng_state: forwarded unchanged to the primitive.
        output: forwarded unchanged to the primitive.
        doutput: forwarded unchanged to the primitive.
        q_seqlen: query sequence lengths, forwarded to the primitive.
        kv_seqlen: key/value sequence lengths, forwarded to the primitive.
        attn_bias_type: bias type forwarded to the primitive.
        attn_mask_type: mask type forwarded to the primitive.
        scaling_factor: forwarded unchanged to the primitive.
        dropout_probability: forwarded unchanged to the primitive.
        is_training: forwarded unchanged to the primitive.

    Returns:
        The result of `FusedAttnBwdPrimitive.outer_primitive.bind`, unmodified.
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None, "bias must be None when attn_bias_type is NVTE_NO_BIAS"
        # The primitive always takes a bias operand; bind an empty one.
        bias = jnp.zeros(0, dtype=q.dtype)

    return FusedAttnBwdPrimitive.outer_primitive.bind(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )