attention.py 148 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
import operator
6
import os
7
import warnings
8
9
10
from dataclasses import dataclass, replace
from functools import partial, reduce
from typing import Optional, Tuple
11

12
import jax
13
import jax.numpy as jnp
14
from jax import dtypes, lax, ffi
15
from jax.sharding import PartitionSpec, NamedSharding
16
from jax.experimental.custom_partitioning import SdyShardingRule
17
18
19

import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
Reese Wang's avatar
Reese Wang committed
20
21
22
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
23
    AttnSoftmaxType,
Reese Wang's avatar
Reese Wang committed
24
25
26
27
28
    QKVLayout,
    QKVFormat,
    CPStrategy,
    SequenceDescriptor,
)
29
from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES, is_mesh_available
30

31
32
33
34
35
from .base import BasePrimitive, register_primitive
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
36
    get_padded_spec,
37
    get_cudnn_version,
38
    get_all_device_compute_capability,
39
40
)
from ..sharding import (
41
42
    global_mesh_resource,
    lax_paral_op,
43
    all_reduce_sum_along_dp_fsdp,
44
45
    get_mesh_axis_size,
    get_mesh_axis_rank,
46
    get_mesh_axis_rank_host,
47
48
    get_all_mesh_axes,
    num_of_devices,
49
    with_sharding_constraint,
50
51
52
)


53
54
55
56
57
# Public names exported by `from <this module> import *`.
__all__ = [
    "FusedAttnHelper",
    "fused_attn_fwd",
    "fused_attn_bwd",
]
58
59


60
61
62
63
64
65
@dataclass(frozen=True)
class _FusedAttnConfig:
    """
    Immutable bundle of the static (non-array) settings of a fused attention call.

    Every field is registered below as pytree metadata (``data_fields`` is empty), so an
    instance contributes no array leaves when flattened.
    """

    # Attention variant selection.
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    softmax_type: AttnSoftmaxType
    qkv_layout: QKVLayout
    # Numerical / mode settings.
    scaling_factor: float
    dropout_probability: float
    is_training: bool
    max_segments_per_seq: int
    # Sliding-window attention settings.
    window_size: Tuple[int, int]
    bottom_right_diagonal: bool
    # Context-parallel settings.
    context_parallel_load_balanced: bool
    cp_axis: str
    # Only for CP + Ring P2P + THD + SWA; None means "use window_size" in the fwd lowering.
    cp_striped_window_size: Optional[Tuple[int, int]]
    # Only for CP + Striped. For Ring P2P, stripe_size=1 only. For AG, stripe_size>=1.
    stripe_size: Optional[int]


_FusedAttnConfig = jax.tree_util.register_dataclass(
    _FusedAttnConfig,
    data_fields=[],
    meta_fields=[
        "attn_bias_type",
        "attn_mask_type",
        "softmax_type",
        "qkv_layout",
        "scaling_factor",
        "dropout_probability",
        "is_training",
        "max_segments_per_seq",
        "window_size",
        "bottom_right_diagonal",
        "context_parallel_load_balanced",
        "cp_axis",
        "cp_striped_window_size",
        "stripe_size",
    ],
)
102
103


104
105
106
107
108
109
@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Describes one fused attention problem and asks the TE backend which kernel can run it.

    Instances are built positionally elsewhere, so the field order below is part of the
    interface and must stay stable.
    """

    is_training: bool
    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: QKVLayout
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    softmax_type: AttnSoftmaxType
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
    q_max_seqlen: int
    kv_max_seqlen: int
    head_dim_qk: int
    head_dim_v: int
    window_size: Tuple[int, int]

    def is_fused_attn_kernel_available(self):
        """Return True unless the backend query resolves to NVTE_No_Backend."""
        chosen = self.get_fused_attn_backend()
        return chosen != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Query transformer_engine_jax for the backend matching this problem description."""
        q_te_dtype = jax_dtype_to_te_dtype(self.q_dtype)
        kv_te_dtype = jax_dtype_to_te_dtype(self.kv_dtype)
        return transformer_engine_jax.get_fused_attn_backend(
            self.is_training,
            q_te_dtype,
            kv_te_dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.softmax_type.value,
            self.dropout_probability,
            self.q_num_heads,
            self.kv_num_heads,
            self.q_max_seqlen,
            self.kv_max_seqlen,
            self.head_dim_qk,
            self.head_dim_v,
            self.window_size[0],
            self.window_size[1],
            # The backend takes a "deterministic" flag, i.e. the negation of the env switch.
            not self.is_non_deterministic_allowed(),
        )

    @staticmethod
    def is_non_deterministic_allowed():
        """Return whether NVTE_ALLOW_NONDETERMINISTIC_ALGO (default "1") is a non-zero int."""
        raw_flag = os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")
        return int(raw_flag) != 0

    @staticmethod
    def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
        """
        Extract the attention problem dimensions from the q/k/v avals for ``qkv_layout``.

        Returns:
            (batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
             q_head_dim, v_head_dim), where batch_shape is a list of leading dims.

        Raises:
            NotImplementedError: for SBHD-format layouts.
            ValueError: for a layout that is neither qkv-packed, kv-packed nor separate.
            AssertionError: when the q/k/v shapes or dtypes are inconsistent.
        """
        if qkv_layout.get_qkv_format() == QKVFormat.SBHD:
            raise NotImplementedError

        if qkv_layout.is_qkvpacked():
            # q_aval holds q, k and v stacked on the axis just before the heads axis.
            *q_batch_shape, q_max_seqlen, packed_count, attn_heads, q_head_dim = q_aval.shape
            kv_max_seqlen, num_gqa_groups, v_head_dim = q_max_seqlen, attn_heads, q_head_dim
            assert packed_count == 3
        elif qkv_layout.is_kvpacked():
            # k_aval holds k and v stacked; v_aval is not inspected for shapes.
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *kv_batch_shape, kv_max_seqlen, packed_count, num_gqa_groups, v_head_dim = (
                k_aval.shape
            )
            assert q_batch_shape == kv_batch_shape
            assert q_head_dim == v_head_dim
            assert packed_count == 2
        elif qkv_layout.is_separate():
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape
            *v_batch_shape, v_max_seqlen, v_num_gqa_groups, v_head_dim = v_aval.shape
            assert (
                q_head_dim == k_head_dim
            ), f"Mismatched q_head_dim: {q_head_dim} and k_head_dim: {k_head_dim}"
            assert (
                k_max_seqlen == v_max_seqlen
            ), f"Mismatched k_max_seqlen: {k_max_seqlen} and v_max_seqlen: {v_max_seqlen}"
            kv_max_seqlen = k_max_seqlen
            assert q_batch_shape == k_batch_shape == v_batch_shape, (
                f"Mismatched qkv batch size for q_batch_shape: {q_batch_shape}, k_batch_shape:"
                f" {k_batch_shape} and v_batch_shape: {v_batch_shape}"
            )
            assert k_num_gqa_groups == v_num_gqa_groups, (
                f"Mismatched k_num_gqa_groups: {k_num_gqa_groups} and v_num_gqa_groups:"
                f" {v_num_gqa_groups}"
            )
            num_gqa_groups = k_num_gqa_groups
        else:
            raise ValueError(f"Unexpected {qkv_layout=}")

        assert q_aval.dtype == k_aval.dtype == v_aval.dtype, (
            f"Mismatched data types for q_aval: {q_aval.dtype}, k_aval: {k_aval.dtype}, v_aval:"
            f" {v_aval.dtype}"
        )
        dims = (
            q_batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        )
        return dims
210
211
212
213
214
215
216
217
218
219
220


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Validates and normalizes the RNG seed fed to fused attention.

    The backend consumes a 64-bit seed and a 64-bit offset. JAX does not enable 64-bit
    integers by default, so each 64-bit value is emulated here as two 32-bit words.
    The offset itself is computed inside the backend.
    """

    rng_state_dtype: jnp.dtype = jnp.uint32
    # uint32 words per seed (the backend's internal seed dtype is int64).
    seed_size: int = 2
    # uint32 words per (seed, offset) pair, both int64 internally.
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Return ``seed`` as an array of ``rng_state_dtype``.

        A ``None`` seed is accepted only while dropout is inactive. It is then replaced by a
        zero placeholder of two words per device, because JAX primitives cannot bind
        ``None``. A seed of any other dtype triggers a warning and is cast.
        """
        wanted_dtype = self.rng_state_dtype

        if seed is None:
            dropout_active = dropout_probability > 0 and is_training
            assert not dropout_active, "seed is not allowed to be None when dropout is enabled."
            seed = jnp.repeat(jnp.zeros(2, dtype=wanted_dtype), num_of_devices())

        if seed.dtype != wanted_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                "Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning."
            )
            seed = seed.astype(wanted_dtype)

        assert seed.dtype == wanted_dtype
        # The backend reads one int64_t seed, i.e. only the leading two uint32 elements.
        assert seed.size >= self.seed_size
        return seed


def generate_cu_seqlen(actual_seqlen):
    """
    Return the cumulative sum of per-sequence lengths, prefixed with a leading 0.

    Negative entries (padding markers) are treated as zero-length sequences. For a 1-D
    input of length n the result has n + 1 elements.
    """
    non_negative_lengths = jnp.where(actual_seqlen < 0, 0, actual_seqlen)
    return jnp.cumulative_sum(non_negative_lengths, include_initial=True)


class FusedAttnFwdPrimitive(BasePrimitive):
    """
    Fused Attention Forward Primitive

    Array operands, in order: q, k, v, bias, softmax_offset, seed, q/kv seqlens (or
    cu_seqlens), q/k sequence offsets, q/kv segment ids, q/kv segment positions. The static
    ``config`` (a ``_FusedAttnConfig``) follows them.
    """

    name = "te_fused_attn_forward_ffi"
    multiple_results = True
    # Position of the static `config` argument of `impl` (after the 14 array operands).
    impl_static_args = (14,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def _bias_batch_and_heads(config, bias_aval):
        """Return (bias_batch, bias_heads) for the backend; both are 0 when there is no bias."""
        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            return 0, 0
        *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
        return reduce(operator.mul, bias_batch_shape), bias_heads

    @staticmethod
    def _effective_window_size(config):
        """
        Return the (left, right) sliding window handed to the fused-attn kernel.

        ``cp_striped_window_size`` (CP + Ring P2P + THD + SWA) takes precedence when set;
        otherwise ``window_size`` is used.
        """
        if config.cp_striped_window_size is not None:
            return config.cp_striped_window_size[0], config.cp_striped_window_size[1]
        return config.window_size[0], config.window_size[1]

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_offset_aval,
        seed_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd abstract

        Returns the avals of (output, softmax_aux, rng_state, workspace).
        """
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert (
            q_dtype == k_dtype == v_dtype == bias_dtype
        ), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}"
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
            f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
            f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
        )

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim)
        out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)

        # backend determines the softmax buffer shape/dtype
        backend = FusedAttnHelper(
            config.is_training,
            q_dtype,
            k_dtype,
            config.qkv_layout,
            config.attn_bias_type,
            config.attn_mask_type,
            config.softmax_type,
            config.dropout_probability,
            attn_heads,
            num_gqa_groups,
            q_max_seqlen,
            kv_max_seqlen,
            q_head_dim,
            v_head_dim,
            config.window_size,
        ).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
            softmax_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            # cuDNN 9.6 reduces the required softmax shape
            if get_cudnn_version() >= (9, 6, 0):
                if config.qkv_layout.is_thd():
                    softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
                else:
                    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
            else:
                softmax_shape = (
                    *batch_shape,
                    attn_heads,
                    q_max_seqlen,
                    config.max_segments_per_seq,
                )
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f"Unsupported {backend=}")
        softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)

        # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
        # 32-bit unsigned int to get the buffer size we need in the C++ kernel
        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

        bias_batch, bias_heads = FusedAttnFwdPrimitive._bias_batch_and_heads(config, bias_aval)

        # Do a dummy kernel call here to get the workspace buffer shapes/dtypes that XLA needs
        # to prepare for the active fused-attn backend. The window and bottom_right_diagonal
        # arguments must be the ones `lowering` passes to the real kernel; otherwise the
        # workspace may be sized for a different kernel configuration.
        window_size_left, window_size_right = FusedAttnFwdPrimitive._effective_window_size(config)
        input_batch = reduce(operator.mul, batch_shape)
        wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            q_head_dim,
            v_head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.softmax_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            config.max_segments_per_seq,
            window_size_left,
            window_size_right,
            config.bottom_right_diagonal,
        )
        wkspace_aval = q_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        # softmax_offset is a per-head float32 tensor for non-vanilla softmax, else empty.
        assert softmax_offset_aval.dtype == jnp.float32
        if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            assert softmax_offset_aval.shape == (1, attn_heads, 1, 1)
        else:
            assert softmax_offset_aval.shape == (0,)

        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract

        Same as `abstract` but without the internal workspace buffer.
        """
        out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract(
            *args, **kwargs
        )
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_offset,
        seed,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd lowering rules
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        input_batch = reduce(operator.mul, batch_shape)
        bias_batch, bias_heads = FusedAttnFwdPrimitive._bias_batch_and_heads(config, bias_aval)
        window_size_left, window_size_right = FusedAttnFwdPrimitive._effective_window_size(config)

        return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
            ctx,
            q,
            k,
            v,
            bias,
            softmax_offset,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,  # ffi_lowering needs number of parameters meets primitive.lowering
            input_batch=input_batch,
            bias_batch=bias_batch,
            q_max_seqlen=q_max_seqlen,
            kv_max_seqlen=kv_max_seqlen,
            attn_heads=attn_heads,
            num_gqa_groups=num_gqa_groups,
            bias_heads=bias_heads,
            qk_head_dim=q_head_dim,
            v_head_dim=v_head_dim,
            max_segments_per_seq=config.max_segments_per_seq,
            scaling_factor=float(config.scaling_factor),
            dropout_probability=float(config.dropout_probability),
            bias_type=int(config.attn_bias_type.value),
            mask_type=int(config.attn_mask_type.value),
            qkv_layout=int(config.qkv_layout.value),
            is_training=config.is_training,
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            bottom_right_diagonal=config.bottom_right_diagonal,
            softmax_type=int(config.softmax_type.value),
        )

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        softmax_offset,
        seed,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd implementation.

        Converts the sequence descriptor inputs into cu_seqlens (and, for THD layouts,
        flattened offsets), then binds the inner primitive.
        Returns (output, softmax_aux, rng_state).
        """
        assert FusedAttnFwdPrimitive.inner_primitive is not None

        sequence_descriptor = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )
        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
                config.qkv_layout,
                config.window_size,
                config.max_segments_per_seq,
            )
        )

        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
                # Compact the entries of x where `condition` holds to the front (row-major
                # order) and pad the rest with fill_value, keeping x's shape.
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
                # Shift each row's non-negative offsets by row_index * max_seqlen.
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
            kv_batch = q_batch = batch[0]

            # Gather valid q_seqlen, which is greater than 0
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1

            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid q_seq_offsets, which is greater and equal to 0
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # And set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _fix_len_take(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _fix_len_take(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_offset,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """Batching rule: output and softmax_aux follow q's batch dim, rng_state follows seed's."""
        check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
        return (
            FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """Derive output/softmax_aux/rng_state shardings from the sharding of q."""
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])

        # when supported softmax_aux shape is (b, s, h, 1) for thd on cudnn 9.6+
        # otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments)
        is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()

        if config.qkv_layout.is_qkvpacked():
            # q_spec = (...batch, q_seqlen, 3, head, hidden)
            batch_spec = q_spec[:-4]
            seqlen_spec, head_spec = q_spec[-4], q_spec[-2]
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
        elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
            # q_spec = (...batch, q_seqlen, head, hidden)
            # k_spec = (...batch, kv_seqlen, [2,] num_gqa_groups, hidden)
            batch_spec = q_spec[:-3]
            seqlen_spec, head_spec = q_spec[-3], q_spec[-2]
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")

        if is_packed_softmax:
            softmax_aux_spec = PartitionSpec(*batch_spec, seqlen_spec, head_spec, None)
        else:
            softmax_aux_spec = PartitionSpec(*batch_spec, head_spec, seqlen_spec, None)
        softmax_aux_sharding = NamedSharding(mesh, softmax_aux_spec)

        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition rule: seed/rng_state are split over all mesh axes; others keep shardings."""
        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[5] = seed_sharding
        # The segment_pos operands reuse the sharding of the matching segment_ids operands.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(FusedAttnFwdPrimitive.impl, config=config)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(config, mesh, value_types, result_types):
        """Shardy sharding rule mirroring `infer_sharding_from_operands`."""
        del mesh, result_types

        # Keep in sync with `infer_sharding_from_operands`.
        # We only need the first input. Fill up the rest with placeholders.
        input_spec = [(f"…{x}",) for x in range(len(value_types))]
        # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
        # instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
        rng_sharding = (f"…{len(value_types)}",)

        if config.qkv_layout.is_qkvpacked():
            input_spec[0] = ("…0", "seqlen", "three", "head", "hidden")
        elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
            input_spec[0] = ("…0", "seqlen", "head", "hidden")
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")

        is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
        out_sharding = ("…0", "seqlen", "head", "hidden")
        if is_packed_softmax:
            softmax_aux_sharding = ("…0", "seqlen", "head", "i")
        else:
            softmax_aux_sharding = ("…0", "head", "seqlen", "i")

        return SdyShardingRule(
            tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding)
        )

734
735
736
737
738
739
740
741

# Register the forward primitive with JAX. NOTE(review): presumably this populates
# FusedAttnFwdPrimitive.inner_primitive / outer_primitive (both start as None, and impl /
# batcher assert they are set) — confirm in .base.register_primitive.
register_primitive(FusedAttnFwdPrimitive)


class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive

    Computes (dq, dk, dv, dbias, dsoftmax_offset) through the
    `te_fused_attn_backward_ffi` custom call.
    """

    name = "te_fused_attn_backward_ffi"
    multiple_results = True
    # `config` is positional argument #17 (0-based) of `impl`.
    impl_static_args = (17,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_offset_aval,
        softmax_aux_aval,
        rng_state_aval,
        output_aval,
        doutput_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd abstract

        Returns the avals of (dq, dk, dv, dbias, dsoftmax_offset, workspace).
        dq/dk/dv/dbias mirror the shape and dtype of q/k/v/bias. The workspace shape and
        dtype are queried from `transformer_engine_jax.get_fused_attn_bwd_workspace_sizes`.
        """
        del softmax_aux_aval, rng_state_aval, output_aval

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            qk_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        deterministic = not FusedAttnHelper.is_non_deterministic_allowed()

        input_batch = reduce(operator.mul, batch_shape)

        # NOTE(review): the workspace is sized with `config.window_size`, while `lowering`
        # launches the kernel with `config.cp_striped_window_size` when that is set.
        # Confirm the workspace size does not depend on the window size.
        wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            qk_head_dim,
            v_head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.softmax_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            deterministic,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
            config.bottom_right_diagonal,
        )

        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(
            shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
        )

        # Validate incoming softmax_offset shape and dtype
        assert (
            softmax_offset_aval.dtype == jnp.float32
        ), f"Incorrect softmax_offset dtype: {softmax_offset_aval.dtype}, expected: {jnp.float32}"
        if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), (
                f"Incorrect softmax_offset shape for {config.softmax_type}:"
                f" {softmax_offset_aval.shape}, expected: (1, {attn_heads}, 1, 1)"
            )
        else:
            assert softmax_offset_aval.shape == (0,), (
                f"Incorrect softmax_offset shape for {config.softmax_type}:"
                f" {softmax_offset_aval.shape}, expected: (0,)"
            )

        if config.softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
            dsoftmax_offset_aval = q_aval.update(
                shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype
            )
        else:
            dsoftmax_offset_aval = q_aval.update(shape=(1, attn_heads, 1, 1), dtype=jnp.float32)

        return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract

        Same as `abstract`, minus the internal workspace aval.
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, _ = (
            FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        )
        return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_offset,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        q_segment_ids,
        kv_segment_ids,
        q_segment_pos,
        kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd lowering rules

        Lowers to the `te_fused_attn_backward_ffi` FFI call. Every operand is forwarded,
        and the static problem description is passed as FFI attributes.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            qk_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        input_batch = reduce(operator.mul, batch_shape)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        # A striped context-parallel window, when configured, takes precedence over
        # `config.window_size`.
        if config.cp_striped_window_size is not None:
            window_size_left = config.cp_striped_window_size[0]
            window_size_right = config.cp_striped_window_size[1]
        else:
            window_size_left = config.window_size[0]
            window_size_right = config.window_size[1]

        return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
            ctx,
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,  # ffi_lowering must receive as many operands as the primitive
            input_batch=input_batch,
            bias_batch=bias_batch,
            q_max_seqlen=q_max_seqlen,
            kv_max_seqlen=kv_max_seqlen,
            attn_heads=attn_heads,
            num_gqa_groups=num_gqa_groups,
            bias_heads=bias_heads,
            qk_head_dim=qk_head_dim,
            v_head_dim=v_head_dim,
            max_segments_per_seq=config.max_segments_per_seq,
            scaling_factor=float(config.scaling_factor),
            dropout_probability=float(config.dropout_probability),
            bias_type=int(config.attn_bias_type.value),
            mask_type=int(config.attn_mask_type.value),
            qkv_layout=int(config.qkv_layout.value),
            is_training=config.is_training,
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            bottom_right_diagonal=config.bottom_right_diagonal,
            softmax_type=int(config.softmax_type.value),
        )

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        softmax_offset,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config,
    ):
        """
        Fused attention bwd implementation.

        Derives sequence lengths and offsets from the sequence descriptor. For THD layouts
        it compacts them (see the inline examples). It then converts lengths to cumulative
        lengths and binds the inner primitive. Returns (dq, dk, dv, dbias, dsoftmax_offset).
        """
        assert FusedAttnBwdPrimitive.inner_primitive is not None

        sequence_descriptor = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )

        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
                config.qkv_layout,
                config.window_size,
                config.max_segments_per_seq,
            )
        )

        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
                # Move the entries of `x` where `condition` holds to the front (in flat
                # order), pad the rest with `fill_value`, and restore the original shape.
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                # TODO(rewang): try indices_are_sorted
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
                # Shift each row's valid (>= 0) offsets by row_index * max_seqlen.
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1
            kv_batch = q_batch = batch[0]

            # Gather valid q_seqlen, which is greater than 0
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1
            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid q_seq_offsets, which are greater than or equal to 0
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # And set the unused positions to the max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _fix_len_take(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _fix_len_take(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        dq, dk, dv, dbias, dsoftmax_offset, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return dq, dk, dv, dbias, dsoftmax_offset

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """Batching rule: rebind the outer primitive; each gradient keeps its primal's batch dim."""
        check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims

        out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
        return (
            FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
        )

    @staticmethod
    def _grad_shardings(mesh, arg_infos):
        """Return shardings for (dq, dk, dv, dbias, dsoftmax_offset).

        Each gradient is sharded exactly like its primal operand, i.e. operands 0-4
        (q, k, v, bias, softmax_offset). Shared by `infer_sharding_from_operands` and
        `partition` so the two cannot drift apart.
        """
        return tuple(
            NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[idx])))
            for idx in range(5)
        )

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """Infer output shardings: every gradient follows its primal operand's sharding."""
        del config, result_infos
        return FusedAttnBwdPrimitive._grad_shardings(mesh, arg_infos)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition the backward pass for custom_partitioning.

        Returns (mesh, sharded_impl, out_shardings, arg_shardings). Segment positions are
        forced to the sharding of the matching segment ids. The dbias gradient (when a bias
        is present) and the dsoftmax_offset gradient (for LEARNABLE_SOFTMAX) are summed
        across the dp/fsdp mesh axes.
        """
        del result_infos
        out_shardings = FusedAttnBwdPrimitive._grad_shardings(mesh, arg_infos)

        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # kv_segment_pos <- kv_segment_ids, q_segment_pos <- q_segment_ids
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)

        def sharded_impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            local_dq, local_dk, local_dv, local_dbias, local_dsoftmax_offset = (
                FusedAttnBwdPrimitive.impl(
                    q,
                    k,
                    v,
                    bias,
                    softmax_offset,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_cu_seqlen,
                    kv_cu_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    _q_segment_ids,
                    _kv_segment_ids,
                    _q_segment_pos,
                    _kv_segment_pos,
                    config=config,
                )
            )
            global_dbias = local_dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)

            global_dsoftmax_offset = local_dsoftmax_offset
            if config.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
                global_dsoftmax_offset = all_reduce_sum_along_dp_fsdp(local_dsoftmax_offset, mesh)

            return local_dq, local_dk, local_dv, global_dbias, global_dsoftmax_offset

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(config, mesh, value_types, result_types):
        """Shardy rule giving every operand and result its own independent ``…<i>`` factor."""
        del config, mesh
        # Keep in sync with `infer_sharding_from_operands`.
        input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
        output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
        return SdyShardingRule(input_spec, output_spec)


register_primitive(FusedAttnBwdPrimitive)


Reese Wang's avatar
Reese Wang committed
1199
def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
    """Reorders a tensor for load balancing the compute of causal attention.

    The sequence axis is cut into ``2 * cp_size`` equal chunks. With
    ``to_contiguous=False`` the chunks are rearranged so that, in rank order, rank ``r``
    receives chunk ``r`` followed by chunk ``2 * cp_size - 1 - r`` (one early and one
    late chunk). ``to_contiguous=True`` undoes that arrangement.

    Args:
        tensor: array to reorder.
        cp_size: number of context-parallel ranks; must be 1 or even.
        seq_dim: axis of ``tensor`` that holds the sequence.
        to_contiguous: direction of the permutation (see above).

    Returns:
        An array with the same shape as ``tensor``, or ``tensor`` itself when
        ``cp_size == 1``.

    Raises:
        ValueError: if ``cp_size`` is odd, or the sequence length is not divisible by
            ``2 * cp_size``.
    """
    if cp_size == 1:
        return tensor

    if cp_size % 2 != 0:
        raise ValueError(f"{cp_size=} must be a multiple of 2.")

    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
    if tensor.shape[seq_dim] % (cp_size * 2) != 0:
        raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")

    num_chunks = 2 * cp_size
    if to_contiguous:
        # Input positions 2r / 2r+1 hold chunks r / num_chunks-1-r. The early chunks sit
        # at even positions (0, 2, 4, ...) and the late ones at odd positions in
        # descending order (num_chunks-1, num_chunks-3, ..., 1).
        half = cp_size // 2
        chunk_pairs = [(4 * i, 4 * i + 2) for i in range(half)]
        chunk_pairs += [(num_chunks - 1 - 4 * i, num_chunks - 3 - 4 * i) for i in range(half)]
    else:
        chunk_pairs = [(rank, num_chunks - 1 - rank) for rank in range(cp_size)]

    # [..., S, ...] -> [..., 2*cp_size, S/(2*cp_size), ...]
    full_shape = tensor.shape
    chunked = tensor.reshape(
        (
            *full_shape[:seq_dim],
            num_chunks,
            full_shape[seq_dim] // num_chunks,
            *full_shape[seq_dim + 1 :],
        )
    )

    # Each take yields [..., 2, S/(2*cp_size), ...]; stacking gives
    # [..., len(chunk_pairs), 2, S/(2*cp_size), ...], which flattens back to full_shape.
    picked = [jnp.take(chunked, jnp.array(pair), axis=seq_dim) for pair in chunk_pairs]
    return jnp.stack(picked, axis=seq_dim).reshape(full_shape)


1251
1252
1253
def reorder_causal_striped(
    tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_size: int = 1
):
    """Reorders a tensor for load balancing with striped pattern.

    The sequence axis is viewed as stripes of ``stripe_size`` consecutive elements.
    In the forward direction (``is_inverse=False``) stripe ``j`` is routed to position
    group ``j % cp_size``, so consecutive stripes are dealt round-robin across the
    ``cp_size`` groups. ``is_inverse=True`` restores the original order.

    Args:
        tensor: array to reorder.
        cp_size: number of context-parallel ranks.
        seq_dim: axis of ``tensor`` that holds the sequence.
        is_inverse: apply the inverse permutation.
        stripe_size: number of consecutive sequence elements per stripe.

    Returns:
        An array with the same shape as ``tensor``.

    Raises:
        ValueError: if ``stripe_size`` is not positive, or the sequence length is not a
            multiple of ``cp_size * stripe_size``.
    """
    origin_shape = tensor.shape
    if stripe_size <= 0:
        raise ValueError(
            f"Incorrect value for CP reordering {stripe_size=}. stripe_size must be a positive"
            " integer"
        )
    if origin_shape[seq_dim] % (cp_size * stripe_size) != 0:
        raise ValueError(
            "Expected origin_shape[seq_dim] is multiple of cp_size*stripe_size but got"
            f" {origin_shape[seq_dim]=}, {cp_size=}, {stripe_size=}, {cp_size*stripe_size=}"
        )

    stripes_per_group = origin_shape[seq_dim] // (cp_size * stripe_size)
    # Forward: [..., S, ...] -> [..., S/(cp*ss), cp, ss, ...]; inverse: [..., cp, S/(cp*ss), ss, ...].
    # Swapping the first two split axes and flattening performs the permutation.
    if is_inverse:
        split_dims = (cp_size, stripes_per_group)
    else:
        split_dims = (stripes_per_group, cp_size)
    split = tensor.reshape(
        [*origin_shape[:seq_dim], *split_dims, stripe_size, *origin_shape[seq_dim + 1 :]]
    )
    return jnp.swapaxes(split, seq_dim, seq_dim + 1).reshape(origin_shape)
Reese Wang's avatar
Reese Wang committed
1283
1284


1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
    """Helper class to assist with running the all-gather strategy for CP attention."""

    mesh: jax.sharding.Mesh
    config: _FusedAttnConfig

    def check_supported(self):
        """Checks if the context parallel implementation is supported by the given arguments.

        Raises:
            ValueError: if the QKV layout, load-balancing scheme (stripe_size), bias type,
                mask type, max_segments_per_seq, dropout probability or softmax type in
                ``self.config`` is not supported by the all-gather CP path.
        """
        header = "Context parallel fused attention"

        allowed_layouts = [
            QKVLayout.BSHD_BS2HD,
            QKVLayout.BSHD_BSHD_BSHD,
            QKVLayout.THD_T2HD,
            QKVLayout.THD_THD_THD,
        ]
        if self.config.qkv_layout not in allowed_layouts:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
            )

        # BSHD must not set stripe_size (dual-chunk balancing); THD must set it (striped).
        if (not self.config.qkv_layout.is_thd() and self.config.stripe_size is not None) or (
            self.config.qkv_layout.is_thd() and self.config.stripe_size is None
        ):
            raise ValueError(
                f"{header} only supports Dual Chunk load balancing with BSHD layouts and Striped"
                " load balancing with THD layouts"
            )

        if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

        allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
        if self.config.qkv_layout.is_thd():
            allowed_masks.append(AttnMaskType.PADDING_CAUSAL_MASK)
        if self.config.attn_mask_type not in allowed_masks:
            raise ValueError(
                f"{header} only supports masking types: "
                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
            )

        # Do not allow CP + AG + THD + Striped with NO_MASK
        if (
            self.config.attn_mask_type is not AttnMaskType.PADDING_CAUSAL_MASK
            and self.config.qkv_layout.is_thd()
        ):
            raise ValueError(f"{header} only supports PADDING_CAUSAL_MASK for THD types")

        # `is_thd` must be *called*: the bare bound method is always truthy, which
        # previously made this BSHD-only guard unreachable.
        if self.config.max_segments_per_seq != 1 and not self.config.qkv_layout.is_thd():
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 for BSHD layouts, got:"
                f" {self.config.max_segments_per_seq}"
            )

        if self.config.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

        if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            raise ValueError(
                f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
            )

1348
1349
    def get_adjusted_mask(self):
        """Converts the mask for context parallelism.

        Causal masks become their bottom-right-aligned variants: CAUSAL_MASK for BSHD
        (all-gather) layouts, PADDING_CAUSAL_MASK for THD layouts. Any other
        mask/layout combination is returned unchanged.
        """
        mask = self.config.attn_mask_type
        thd_layout = self.config.qkv_layout.is_thd()
        if mask == AttnMaskType.CAUSAL_MASK and not thd_layout:  # BSHD AG case only
            return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
        if mask == AttnMaskType.PADDING_CAUSAL_MASK and thd_layout:  # THD AG case only
            return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
        return mask

1362
1363
1364
1365
1366
1367
1368
    def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size):
        """Converts the max segments per seq for context parallelism AG + THD.

        Estimate: ``max_seqlen // (stripe_size * cp_size)`` plus the configured
        ``max_segments_per_seq``.
        """
        # NOTE(review): presumably each stripe boundary may split a segment, adding at
        # most one extra segment per stripe — confirm this bound.
        stripe_count = max_seqlen // (self.config.stripe_size * cp_size)
        return stripe_count + self.config.max_segments_per_seq

1369
1370
    def get_step_config(self) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention.

        The mask comes from `get_adjusted_mask`, and `bottom_right_diagonal` follows
        whether that mask is bottom-right aligned. `cp_striped_window_size` is cleared.
        Every other listed field is copied from ``self.config``.
        """
        cfg = self.config
        step_mask = self.get_adjusted_mask()
        return _FusedAttnConfig(
            attn_bias_type=cfg.attn_bias_type,
            attn_mask_type=step_mask,
            softmax_type=cfg.softmax_type,
            qkv_layout=cfg.qkv_layout,
            scaling_factor=cfg.scaling_factor,
            dropout_probability=cfg.dropout_probability,
            is_training=cfg.is_training,
            max_segments_per_seq=cfg.max_segments_per_seq,
            window_size=cfg.window_size,
            bottom_right_diagonal=step_mask.is_bottom_right(),
            context_parallel_load_balanced=cfg.context_parallel_load_balanced,
            cp_axis=cfg.cp_axis,
            cp_striped_window_size=None,
            stripe_size=cfg.stripe_size,
        )

    def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for a single CP step call made via a striped AG primitive.

        Differs from `get_step_config` only in `max_segments_per_seq`, which is replaced by
        `get_adjusted_max_segments_per_seq(max_seqlen, cp_size)`.
        """
        cfg = self.config
        step_mask = self.get_adjusted_mask()
        step_segments = self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size)
        return _FusedAttnConfig(
            attn_bias_type=cfg.attn_bias_type,
            attn_mask_type=step_mask,
            softmax_type=cfg.softmax_type,
            qkv_layout=cfg.qkv_layout,
            scaling_factor=cfg.scaling_factor,
            dropout_probability=cfg.dropout_probability,
            is_training=cfg.is_training,
            max_segments_per_seq=step_segments,
            window_size=cfg.window_size,
            bottom_right_diagonal=step_mask.is_bottom_right(),
            context_parallel_load_balanced=cfg.context_parallel_load_balanced,
            cp_axis=cfg.cp_axis,
            cp_striped_window_size=None,
            stripe_size=cfg.stripe_size,
        )

1409
    def all_gather_kv(self, k, v):
        """Performs an all-gather of k and v over context parallel ranks.

        Tensors are gathered (tiled) along axis 1. When load balancing is enabled, the
        gathered tensor is reordered back to contiguous order: striped for THD layouts,
        dual-chunk-swap otherwise. KV-packed layouts gather only ``k`` and return ``v``
        untouched. Separate layouts gather both. Any other layout returns both unchanged.
        """
        cfg = self.config

        def gather_along_seq(x):
            gathered = lax_paral_op(
                x, lax.all_gather, cfg.cp_axis, mesh=self.mesh, axis=1, tiled=True
            )
            if not cfg.context_parallel_load_balanced:
                return gathered
            cp_size = get_mesh_axis_size(cfg.cp_axis, self.mesh)
            if cfg.qkv_layout.is_thd():
                return reorder_causal_striped(gathered, cp_size, 1, True, cfg.stripe_size)
            return reorder_causal_dual_chunk_swap(gathered, cp_size, 1, to_contiguous=True)

        layout = cfg.qkv_layout
        if layout.is_kvpacked():
            return gather_along_seq(k), v
        if layout.is_separate():
            return gather_along_seq(k), gather_along_seq(v)
        return k, v  # fall through

1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
    def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos):
        """Performs an all-gather of kv segment ids and kv segment pos over context parallel ranks.

        Both arrays are gathered (tiled) along axis 1. With load balancing on a THD layout
        they are then un-striped with `reorder_causal_striped(..., is_inverse=True)`.
        Returns the pair (segment_ids, segment_pos).
        """
        # NOTE(review): for load-balanced non-THD layouts the gathered arrays are returned
        # without the dual-chunk reordering that `all_gather_kv` applies. Presumably this
        # method is only used on the THD path — confirm.
        cfg = self.config
        gathered = tuple(
            lax_paral_op(arr, lax.all_gather, cfg.cp_axis, mesh=self.mesh, axis=1, tiled=True)
            for arr in (kv_segment_ids, kv_segment_pos)
        )
        if cfg.context_parallel_load_balanced:
            cp_size = get_mesh_axis_size(cfg.cp_axis, self.mesh)
            if cfg.qkv_layout.is_thd():
                return tuple(
                    reorder_causal_striped(arr, cp_size, 1, True, cfg.stripe_size)
                    for arr in gathered
                )
        return gathered  # fall through

1451
1452
1453
1454
    def reduce_scatter_dkv(self, dk, dv):
        """Performs a reduce-scatter of dk and dv over context parallel ranks.

        When load balancing is enabled, each gradient is first permuted into the balanced
        order (striped for THD, dual-chunk-swap otherwise). It is then summed and
        scattered (tiled) along axis 1 with `lax.psum_scatter`. KV-packed layouts scatter
        only ``dk``. Separate layouts scatter both. Any other layout returns both unchanged.
        """
        cfg = self.config

        def scatter_along_seq(grad):
            if cfg.context_parallel_load_balanced:
                cp_size = get_mesh_axis_size(cfg.cp_axis, self.mesh)
                if cfg.qkv_layout.is_thd():
                    grad = reorder_causal_striped(grad, cp_size, 1, False, cfg.stripe_size)
                else:
                    grad = reorder_causal_dual_chunk_swap(grad, cp_size, 1, to_contiguous=False)
            return lax_paral_op(
                grad,
                lax.psum_scatter,
                cfg.cp_axis,
                mesh=self.mesh,
                scatter_dimension=1,
                tiled=True,
            )

        layout = cfg.qkv_layout
        if layout.is_kvpacked():
            return scatter_along_seq(dk), dv
        if layout.is_separate():
            return scatter_along_seq(dk), scatter_along_seq(dv)
        return dk, dv  # fall through

    def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
        """Return the two KV sequence lengths used by the sub ranks of ``cp_rank``.

        With ``s = kv_seqlen_per_subrank`` the result is a two-element list:

        * load balanced:   ``[(cp_rank + 1) * s, kv_max_seqlen - cp_rank * s]``
        * unbalanced:      ``[(2 * cp_rank + 1) * s, (2 * cp_rank + 2) * s]``

        Example: CP=4, MaxLen = 1024, Unbalanced
           cp_rank 0: [128, 256]
           cp_rank 1: [384, 512]
           cp_rank 2: [640, 768]
           cp_rank 3: [896, 1024]

        Example: CP=4, MaxLen = 1024, Balanced
           cp_rank 0: [128, 1024]
           cp_rank 1: [256, 896]
           cp_rank 2: [384, 768]
           cp_rank 3: [512, 640]
        """
        step = kv_seqlen_per_subrank
        if not self.config.context_parallel_load_balanced:
            return [(2 * cp_rank + 1) * step, (2 * cp_rank + 2) * step]
        return [(cp_rank + 1) * step, kv_max_seqlen - cp_rank * step]

    def slice_kv(self, k, v, slice_seq_len):
        """Keep only the first ``slice_seq_len`` positions of k and v along axis 1.

        Slicing uses ``lax.dynamic_slice_in_dim`` starting at index 0. For KV-packed layouts only
        ``k`` is sliced and ``v`` is passed through unchanged (presumably ``k`` carries the packed
        K/V tensor — confirm against callers). For separate layouts both are sliced. Any other
        layout returns both inputs unchanged.

        Returns:
            Tuple ``(k, v)``.
        """
        layout = self.config.qkv_layout

        def sliced(x):
            return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)

        if layout.is_kvpacked():
            return sliced(k), v
        if layout.is_separate():
            return sliced(k), sliced(v)

        return k, v  # fall through

    def pad_kv(self, dk, dv, pad_seq_len):
        """Zero-pad dk and dv by ``pad_seq_len`` positions at the end of axis 1.

        Only axis 1 is padded (at its end); every other axis is left untouched, so the same code
        serves the 5-D KV-packed and 4-D separate gradient tensors without hard-coding their rank.
        For KV-packed layouts only ``dk`` is padded and ``dv`` is passed through unchanged. For
        separate layouts both are padded. Any other layout returns both inputs unchanged.

        Returns:
            Tuple ``(dk, dv)``.
        """
        layout = self.config.qkv_layout

        def pad(x):
            widths = [(0, 0), (0, pad_seq_len)] + [(0, 0)] * (x.ndim - 2)
            return jnp.pad(x, widths, "constant", constant_values=0.0)

        if layout.is_kvpacked():
            return pad(dk), dv
        if layout.is_separate():
            return pad(dk), pad(dv)

        return dk, dv  # fall through

1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
    # Below are the sharded post AG q seg ids and pos for a given rank:
    # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
    # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
    # max_segments_per_seq = 7
    # Below are some intermediate representations:
    # non_zero_indices = [[ 0,  1,  2,  3,  8,  9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]]
    # segment_changes = [[ True, False, False, False,  True, False, False, False,  True, False, False, False,  True,  True,  True,  True]]
    # seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]
    # seqlens_all_pad_neg = [[ 4,  4,  4, -1, -1, -1, -1]]
    def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq):
        """Extract the q seqlens for the striped primitive (post AG) from the sharded q seg ids/pos.

        Padding tokens (segment id 0) are first packed out of each row. Among the remaining
        tokens a new run begins at the first token, wherever the segment id changes, and
        wherever the position is not the previous position + 1. The tokens of each run are
        counted.

        Args:
            q_segment_ids: 2-D (rows x tokens) segment ids, as in the example above; 0 is padding.
            q_segment_pos: segment positions with the same shape.
            max_segments_per_seq: number of output slots per row. Runs beyond this are dropped
                (``jnp.bincount`` with a fixed ``length``).

        Returns:
            Integer array ``(rows, max_segments_per_seq)`` of run lengths, unused slots set to -1.
        """
        row_len = q_segment_ids.shape[-1]

        def _left_pack_indices(keep_row):
            # Indices of non-padding tokens packed to the front; unused tail slots hold -1.
            return jnp.where(keep_row, size=row_len, fill_value=-1)[0]

        pack_idx = jax.vmap(_left_pack_indices)(q_segment_ids != 0)
        in_range = pack_idx >= 0
        gather_idx = jnp.clip(pack_idx, 0, None)  # -1 -> 0 so the gather stays in bounds
        packed_ids = jnp.where(
            in_range, jnp.take_along_axis(q_segment_ids, gather_idx, axis=-1), 0
        )
        packed_pos = jnp.where(
            in_range, jnp.take_along_axis(q_segment_pos, gather_idx, axis=-1), 0
        )
        is_token = packed_ids != 0

        # A run starts at the first real token and wherever the id changes or the position is
        # not contiguous with its left neighbour. Trailing padding always registers as a start
        # here, but it is excluded from the counts below.
        id_jump = packed_ids[..., 1:] != packed_ids[..., :-1]
        pos_jump = packed_pos[..., 1:] != packed_pos[..., :-1] + 1
        run_starts = jnp.concatenate([is_token[..., :1], id_jump | pos_jump], axis=-1)
        run_label = jnp.cumsum(run_starts, axis=-1)

        def _count_per_run(token_row, label_row):
            # Label 0 collects padding; labels 1..max_segments_per_seq are the runs.
            labels = jnp.where(token_row, label_row, 0).astype(jnp.int32)
            return jnp.bincount(labels, length=max_segments_per_seq + 1)[1:]

        run_lengths = jax.vmap(_count_per_run)(is_token, run_label)
        return jnp.where(run_lengths == 0, -1, run_lengths)

    # Below are the sharded post AG q seg ids and pos for a given rank:
    # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
    # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
    # max_segments_per_seq = 7
    # Below are some intermediate representations:
    # segment_changes = [[ True, False, False, False,  True, False, False, False,  True, False, False, False,  True, False, False, False]]
    # segment_changes_masked = [[ True, False, False, False, False, False, False, False,  True, False, False, False,  True, False, False, False]]
    # seq_offsets = [[ 0,  8, 12, -1, -1, -1, -1]]
    def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq):
        """Extract the q seqoffsets for the striped primitive (post AG) from the sharded q seg ids
        and seg pos.

        A non-padding token (segment id != 0) starts a new sequence when it is the first token
        of its row, when its segment id differs from the previous token's, or when its position
        is not the previous token's position + 1. The id check keeps these offsets aligned with
        ``q_seqlens_for_striped_for_rank``, which also splits on id changes. A position-only
        check misses a boundary where the next segment's positions happen to continue the
        previous segment's (e.g. ids ``[1,1,2,2]`` with positions ``[0,1,2,3]``).

        Args:
            q_segment_ids: 2-D (rows x tokens) segment ids; 0 marks padding.
            q_segment_pos: segment positions with the same shape.
            max_segments_per_seq: number of offset slots returned per row (extra starts are
                truncated by ``jnp.where(..., size=...)``).

        Returns:
            Array ``(rows, max_segments_per_seq)`` with the token index at which each sequence
            starts, padded with -1.
        """
        row_start = jnp.full((q_segment_pos.shape[0], 1), True, dtype=bool)
        pos_break = q_segment_pos[..., 1:] != q_segment_pos[..., :-1] + 1
        id_break = q_segment_ids[..., 1:] != q_segment_ids[..., :-1]
        segment_changes = jnp.concatenate([row_start, pos_break | id_break], axis=-1)
        # Remove any segment changes that land on padding tokens.
        segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False)
        # The indices of the remaining segment changes are the offsets.
        return jax.vmap(
            lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq, fill_value=-1)[0]
        )(segment_changes_masked)

    # Below are the sharded post AG q seg ids and pos for a given rank:
    # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
    # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
    # max_segments_per_seq = 7
    # Below are some intermediate representations:
    # non_zero_mask = [[ True,  True,  True,  True, False, False, False, False,  True, True,  True,  True,  True,  True,  True,  True]]
    # non_zero_indices = [[ 0,  1,  2,  3,  8,  9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]]
    # segment_changes = [[False, False, False,  True, False, False, False,  True, False, False, False,  True,  True,  True,  True, False]]
    # selected_values = [[ 4, 15, 31, -1, -1, -1, -1]]
    def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq):
        """Extract the kv seqlens for the striped primitive (post AG) from the sharded kv seg ids
        and seg pos.

        Padding tokens (segment id 0) are packed out of each row first. A real token ends a run
        when the next token is real with a different segment id, when the next position is not
        this position + 1, or when it is the last slot of the row. For every run end the value
        ``position + 1`` is reported.

        Args:
            kv_segment_ids: 2-D (rows x tokens) segment ids; 0 marks padding.
            kv_segment_pos: segment positions with the same shape.
            max_segments_per_seq: number of output slots per row (extra run ends are truncated).

        Returns:
            Array ``(rows, max_segments_per_seq)`` of ``last_position + 1`` per run, -1 padded.
        """
        row_len = kv_segment_ids.shape[-1]

        def _left_pack_indices(keep_row):
            # Indices of non-padding tokens packed to the front; unused tail slots hold -1.
            return jnp.where(keep_row, size=row_len, fill_value=-1)[0]

        pack_idx = jax.vmap(_left_pack_indices)(kv_segment_ids != 0)
        in_range = pack_idx >= 0
        gather_idx = jnp.clip(pack_idx, 0, None)  # -1 -> 0 so the gather stays in bounds
        packed_ids = jnp.where(
            in_range, jnp.take_along_axis(kv_segment_ids, gather_idx, axis=-1), 0
        )
        packed_pos = jnp.where(
            in_range, jnp.take_along_axis(kv_segment_pos, gather_idx, axis=-1), 0
        )
        is_token = packed_ids != 0

        next_id_differs = (packed_ids[..., 1:] != packed_ids[..., :-1]) & is_token[..., 1:]
        next_pos_jumps = packed_pos[..., 1:] != packed_pos[..., :-1] + 1
        run_ends = jnp.concatenate([next_id_differs | next_pos_jumps, is_token[..., -1:]], axis=-1)

        def _end_indices(end_row, token_row):
            # Only ends that sit on real tokens count.
            return jnp.where(end_row & token_row, size=max_segments_per_seq, fill_value=-1)[0]

        end_idx = jax.vmap(_end_indices)(run_ends, is_token)
        last_pos = jnp.take_along_axis(packed_pos, jnp.maximum(end_idx, 0), axis=-1)
        return jnp.where(end_idx >= 0, last_pos + 1, -1)

    # Below are the sharded post AG kv seg ids and pos for a given rank:
    # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
    # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
    # kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    #                       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    #                       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
    # kv_segment_pos_ag = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
    #                       18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
    #                       15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
    #                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
    # max_segments_per_seq = 7
    # Below are some intermediate representations:
    # segment_changes_first_true_masked = [[ True, False, False, False, False, False, False, False, True,
    #                                       False, False, False,  True, False, False, False]]
    # segment_changes_indices = [[ 0,  8, 12, -1, -1, -1, -1]]
    # segment_ids = [[ 1,  2,  2, -1, -1, -1, -1]]
    # segment_changes_ag_first_true_masked = [[ True, False, False, False, False, False, False, False, False,
    #                                               False, False, False, False, False, False, False, False, False,
    #                                               False, False, False,  True, False, False, False, False, False,
    #                                               False, False, False, False, False, False, False, False, False,
    #                                               False, False, False, False, False, False, False, False, False,
    #                                               False, False, False, False, False, False, False, False, False,
    #                                               False, False, False, False, False, False, False, False, False,
    #                                               False]
    # segment_changes_ag_indices = [[ 0, 21, -1, -1, -1, -1, -1]]
    # seq_offsets = [[ 0, 21, 21, -1, -1, -1, -1]]
    def kv_seqoffsets_for_striped_for_rank(
        self,
        kv_segment_pos,
        kv_segment_ids,
        kv_segment_pos_ag,
        kv_segment_ids_ag,
        max_segments_per_seq,
    ):
        """Extract the kv seqoffsets for the striped primitive (post AG) from the sharded kv seg
        ids and seg pos, AG kv seg ids and seg pos.

        Steps:
          1. Find where each run of tokens starts in this rank's shard and which segment id it
             belongs to.
          2. Find where each segment starts in the all-gathered sequence.
          3. Map each run's segment id ``k`` to the ``k``-th all-gathered start. This assumes
             segment ids are numbered 1..N in order of appearance (NOTE(review): confirm).

        A non-padding token starts a run when it is the first token of its row, when its
        position is not the previous token's position + 1, or when its segment id differs from
        the previous token's. The id check keeps run boundaries consistent with
        ``kv_seqlens_for_striped_for_rank``, which also ends a run on an id change.

        Note the argument order: positions come before ids.

        Returns:
            Array ``(rows, max_segments_per_seq)`` of offsets into the all-gathered sequence,
            padded with -1.
        """

        def _segment_starts(segment_ids, segment_pos):
            # Per-row token indices where a run starts (non-padding only), padded with -1.
            first_col = jnp.full((segment_pos.shape[0], 1), True, dtype=bool)
            pos_break = segment_pos[..., 1:] != segment_pos[..., :-1] + 1
            id_break = segment_ids[..., 1:] != segment_ids[..., :-1]
            starts = jnp.concatenate([first_col, pos_break | id_break], axis=-1)
            starts = jnp.where(segment_ids != 0, starts, False)
            return jax.vmap(
                lambda row: jnp.where(row, size=max_segments_per_seq, fill_value=-1)[0]
            )(starts)

        # Run starts within this rank's shard, and the segment id owning each run (-1 if unused).
        rank_starts = _segment_starts(kv_segment_ids, kv_segment_pos)
        rank_run_ids = jax.vmap(
            lambda sci_row, ksi_row: jnp.where(sci_row >= 0, ksi_row[sci_row], -1)
        )(rank_starts, kv_segment_ids)

        # Start index of each segment in the all-gathered sequence.
        ag_starts = _segment_starts(kv_segment_ids_ag, kv_segment_pos_ag)

        # Use the segment ids picked per rank to get the offsets from the AG start indices.
        return jax.vmap(lambda si_row, sca_row: jnp.where(si_row > 0, sca_row[si_row - 1], -1))(
            rank_run_ids, ag_starts
        )

1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765

class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Attention Forward with Context Parallelism Primitive

    This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the per-device forward implementation for an all-gather CP mesh.

        Falls back to ``FusedAttnFwdPrimitive.partition`` when the CP axis has size 1.
        Otherwise K/V are all-gathered across the CP axis. Each rank runs two fused-attention
        steps, one per half of its local Q shard, and ``lax.switch`` on the rank index selects
        the unrolled per-rank variant at runtime.

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)``.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        # The seed input (argument 5) and the rng_state output share a single sharding.
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[5] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # cuDNN does not support right-aligned masking with dynamic sequence length padding.
            # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
            # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
            # meeting the expectation of the SPMD model.
            # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
            # mask/sequence length tensor to avoid this unrolled loop.
            def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                q_split = jnp.split(q, 2, axis=1)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == AttnMaskType.NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    # Integer division keeps the per-step seqlens integral and identical to the
                    # values the backward primitive computes for the same step.
                    q_seqlen_for_step = q_seqlen // (cp_size * 2)
                    num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks
                    output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        softmax_offset,
                        seed,
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )
                    results.append((output, softmax_aux, rng_state))

                output = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2)
                rng_state = results[1][2]  # Use the final RNG state

                return output, softmax_aux, rng_state

            k_ag, v_ag = helper.all_gather_kv(k, v)

            functions = [
                partial(
                    _cross_attn, idx, q, k_ag, v_ag, bias, softmax_offset, q_seqlen, kv_seqlen, seed
                )
                for idx in range(cp_size)
            ]

            return lax.switch(cp_rank, functions)

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPWithAllGatherFwdPrimitive)


class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Attention Backward with Context Parallelism Primitive.

    This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks.
    The gradients are subsequently reduce-scattered back to each context parallel rank.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the per-device backward implementation for an all-gather CP mesh.

        Falls back to ``FusedAttnBwdPrimitive.partition`` when the CP axis has size 1.
        Otherwise the steps are:
          1. K/V are all-gathered.
          2. Each rank runs two fused-attention backward steps, one per half of its local Q
             shard; ``lax.switch`` on the rank index selects the unrolled per-rank variant.
          3. The resulting dK/dV are reduce-scattered back over the CP axis.

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)``. ``impl`` returns
            ``(dq, dk, dv, dbias, dsoftmax_offset)``, sharded like
            ``(q, k, v, bias, softmax_offset)``.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Ensure we can support this configuration with context parallelism.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        softmax_offset_spec = get_padded_spec(arg_infos[4])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            dq_sharding,
            dk_sharding,
            dv_sharding,
            dbias_sharding,
            dsoftmax_offset_sharding,
        )

        def impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # See the comment in FusedAttnCPWithAllGatherFwdPrimitive.partition for why each CP
            # rank gets its own unrolled function selected with lax.switch.
            def _cross_attn_bwd(
                idx,
                q,
                k,
                v,
                bias,
                softmax_offset,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlen,
                kv_seqlen,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
            ):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                q_split = jnp.split(q, 2, axis=1)
                output_split = jnp.split(output, 2, axis=1)
                doutput_split = jnp.split(doutput, 2, axis=1)
                softmax_aux_split = jnp.split(softmax_aux, 2, axis=2)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == AttnMaskType.NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    q_seqlen_for_step = q_seqlen // (cp_size * 2)
                    num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

                    dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        softmax_offset,
                        softmax_aux_split[sub_idx],
                        rng_state,
                        output_split[sub_idx],
                        doutput_split[sub_idx],
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )

                    # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
                    if config.attn_mask_type != AttnMaskType.NO_MASK:
                        pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx]
                        dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length)

                    results.append((dq_local, dk_local, dv_local, dbias_local))

                dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                dk_local_pad = results[0][1] + results[1][1]
                dv_local_pad = results[0][2] + results[1][2]
                # NOTE(review): only the second sub-step's dbias is returned; the first one is
                # discarded. Confirm bias gradients are unsupported or irrelevant under this CP path.
                return dq_local, dk_local_pad, dv_local_pad, results[1][3]

            k_ag, v_ag = helper.all_gather_kv(k, v)

            functions = [
                partial(
                    _cross_attn_bwd,
                    idx,
                    q,
                    k_ag,
                    v_ag,
                    bias,
                    softmax_offset,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_seqlen,
                    kv_seqlen,
                    _q_segment_ids,
                    _kv_segment_ids,
                    _q_segment_pos,
                    _kv_segment_pos,
                )
                for idx in range(cp_size)
            ]

            dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
            dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)

            # All-gather CP does not compute a softmax_offset gradient. Return zeros with the
            # right shape/dtype for arity matching; jnp.empty_like is also zero-filled in JAX,
            # so this only makes the value explicit.
            dummy_dsoftmax_offset = jnp.zeros_like(softmax_offset)
            return dq, dk, dv, dbias, dummy_dsoftmax_offset

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)


class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Attention Forward with Context Parallelism and Striped Load Balancing Primitive

    This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition the forward pass for the striped all-gather context parallel strategy.

        If the ``config.cp_axis`` mesh axis has size 1, this delegates to
        ``FusedAttnFwdPrimitive.partition``. Otherwise it returns
        ``(mesh, impl, out_shardings, arg_shardings)``. ``impl`` all-gathers K/V and the KV
        segment ids/positions, derives the per-rank sequence lengths and offsets from the
        segment ids/positions, and runs one fused attention call.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        # The seed input (argument 5) and the RNG state output share one sharding over all
        # mesh axes.
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[5] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):  # pylint: disable=unused-argument
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)

            # Each rank sees its (already reordered) shard of q and of the q/kv segment ids and
            # positions. K, V and the KV segment ids/positions are all-gathered so that every
            # rank attends to the full KV sequence.
            k_ag, v_ag = helper.all_gather_kv(k, v)
            kv_segment_ids_ag, kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(
                _kv_segment_ids, _kv_segment_pos
            )

            # The per-rank work depends only on the rank-local segment ids/positions, not on the
            # rank index. A single call is therefore equivalent to a lax.switch over cp_size
            # identical branches, and it avoids compiling the fused attention cp_size times.
            kv_max_seqlen = k_ag.shape[1]
            # Estimate an adjusted max_segments_per_seq per rank based on the global
            # max_segments_per_seq.
            adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(
                max_seqlen=kv_max_seqlen, cp_size=cp_size
            )
            q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(
                _q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq
            )
            q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(
                q_segment_ids=_q_segment_ids,
                q_segment_pos=_q_segment_pos,
                max_segments_per_seq=adjusted_max_segments_per_seq,
            )
            kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(
                kv_segment_ids=_kv_segment_ids,
                kv_segment_pos=_kv_segment_pos,
                max_segments_per_seq=adjusted_max_segments_per_seq,
            )
            kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(
                kv_segment_pos=_kv_segment_pos,
                kv_segment_ids=_kv_segment_ids,
                kv_segment_pos_ag=kv_segment_pos_ag,
                kv_segment_ids_ag=kv_segment_ids_ag,
                max_segments_per_seq=adjusted_max_segments_per_seq,
            )

            # Pass the pre-computed seqlens/offsets, and empty placeholders for the segment
            # ids/pos, so the base primitive uses the former instead of re-deriving them.
            output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                q,  # sharded for rank
                k_ag,  # all-gathered
                v_ag,  # all-gathered
                bias,
                softmax_offset,
                seed,
                q_seqlens_for_rank,
                kv_seqlens_for_rank,
                q_seq_offsets_for_rank,
                kv_seq_offsets_for_rank,
                jnp.zeros(0),
                jnp.zeros(0),
                jnp.zeros(0),
                jnp.zeros(0),
                config=helper.get_step_config_for_striped(
                    max_seqlen=kv_max_seqlen, cp_size=cp_size
                ),
            )
            return output, softmax_aux, rng_state

        return mesh, impl, out_shardings, arg_shardings


# Register the striped all-gather CP forward primitive using `.base.register_primitive`.
register_primitive(FusedAttnCPStripedWithAllGatherFwdPrimitive)


class FusedAttnCPStripedWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Attention Backward with Context Parallelism and Striped Load Balancing Primitive.

    This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks.
    The gradients are subsequently reduce-scattered back to each context parallel rank.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition the backward pass for the striped all-gather context parallel strategy.

        If the ``config.cp_axis`` mesh axis has size 1, this delegates to
        ``FusedAttnBwdPrimitive.partition``. Otherwise it returns
        ``(mesh, impl, out_shardings, arg_shardings)``. Each gradient output is sharded like
        its corresponding primal input (q, k, v, bias, softmax_offset).
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Ensure we can support this configuration with context parallelism.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        softmax_offset_spec = get_padded_spec(arg_infos[4])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            dq_sharding,
            dk_sharding,
            dv_sharding,
            dbias_sharding,
            dsoftmax_offset_sharding,
        )

        def impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):  # pylint: disable=unused-argument
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)

            # AG the k, v, kv_segment_ids and kv_segment_pos
            k_ag, v_ag = helper.all_gather_kv(k, v)
            kv_segment_ids_ag, kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(
                _kv_segment_ids, _kv_segment_pos
            )

            # The per-rank work depends only on the rank-local segment ids/positions, not on the
            # rank index. A single call is therefore equivalent to a lax.switch over cp_size
            # identical branches.
            kv_max_seqlen = k_ag.shape[1]
            # Estimate an adjusted max_segments_per_seq per rank based on the global
            # max_segments_per_seq.
            adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(
                max_seqlen=kv_max_seqlen, cp_size=cp_size
            )
            q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(
                _q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq
            )
            q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(
                q_segment_ids=_q_segment_ids,
                q_segment_pos=_q_segment_pos,
                max_segments_per_seq=adjusted_max_segments_per_seq,
            )
            kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(
                kv_segment_ids=_kv_segment_ids,
                kv_segment_pos=_kv_segment_pos,
                max_segments_per_seq=adjusted_max_segments_per_seq,
            )
            kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(
                kv_segment_pos=_kv_segment_pos,
                kv_segment_ids=_kv_segment_ids,
                kv_segment_pos_ag=kv_segment_pos_ag,
                kv_segment_ids_ag=kv_segment_ids_ag,
                max_segments_per_seq=adjusted_max_segments_per_seq,
            )

            # Pass the pre-computed seqlens/offsets, and empty placeholders for the segment
            # ids/pos, down to FusedAttnBwdPrimitive so it uses the former directly.
            dq, dk_local, dv_local, dbias, _ = FusedAttnBwdPrimitive.impl(
                q,  # sharded for rank
                k_ag,  # all-gathered
                v_ag,  # all-gathered
                bias,
                softmax_offset,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlens_for_rank,
                kv_seqlens_for_rank,
                q_seq_offsets_for_rank,
                kv_seq_offsets_for_rank,
                jnp.zeros(0),
                jnp.zeros(0),
                jnp.zeros(0),
                jnp.zeros(0),
                config=helper.get_step_config_for_striped(
                    max_seqlen=kv_max_seqlen, cp_size=cp_size
                ),
            )
            # RS the dk and dv back to this rank's shard.
            dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)

            # Return a zero dsoftmax_offset for arity matching (all-gather CP doesn't use it).
            # jnp.empty_like would also produce zeros in JAX; zeros_like makes that explicit.
            dummy_dsoftmax_offset = jnp.zeros_like(softmax_offset)
            return dq, dk, dv, dbias, dummy_dsoftmax_offset

        return mesh, impl, out_shardings, arg_shardings


# Register the striped all-gather CP backward primitive using `.base.register_primitive`.
register_primitive(FusedAttnCPStripedWithAllGatherBwdPrimitive)


@dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper:
    """Helper class to assist with running the P2P ring strategy for CP attention."""

    mesh: jax.sharding.Mesh
    config: _FusedAttnConfig

    @staticmethod
    def use_scanloop():
        """Returns true if the implementation will use a scan loop for iteration.

        Controlled by the ``NVTE_FUSED_RING_ATTENTION_USE_SCAN`` environment variable, which is
        parsed with ``int`` and defaults to ``"0"`` (disabled).
        """
        # TODO(KshitijLakhani): Reset default to 1, once the extra kv permute op issue is resolved
        use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "0")))
        return use_scan

    def check_supported(self):
        """Checks if the context parallel implementation is supported by the given arguments.

        Raises:
            ValueError: if the layout, bias, mask, segment count, dropout, softmax type or
                sliding-window/scan-loop combination is not supported by ring attention.
        """
        header = "Context parallel fused ring attention"

        if self.config.qkv_layout.is_thd():
            allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
        else:
            allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]

        if self.config.qkv_layout not in allowed_layouts:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
            )

        if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

        if self.config.qkv_layout.is_thd():
            allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK]
        else:
            allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
        if self.config.attn_mask_type not in allowed_masks:
            # Single separating space, matching the layout error message above.
            raise ValueError(
                f"{header} only supports masking types:"
                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
            )

        if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got:"
                f" {self.config.max_segments_per_seq}"
            )

        if self.config.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

        if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            raise ValueError(
                f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
            )

        # TODO(KshitijLakhani): Flip the condition to check for disabled scan loop and warn
        # against using unrolled loops once the scan issue is resolved.
        # We want to discourage the use of scan loop as additional kv permute op observed.
        # The scan loop flavor will be supported but not the preferred implementation until
        # a resolution for the additional kv permute op, which degrades perf, is found.
        if self.use_scanloop():
            warnings.warn(
                "Scan loop is enabled for fused ring attention. To disable set"
                " NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 in your environment"
            )

        # If using scanloop, idx in scan_kv_block() will be a traced device value, but
        # _normalize_window_size_for_cp_striped() requires all parameters to be host values
        is_context_parallel = get_mesh_axis_size(self.config.cp_axis, self.mesh) > 1
        is_thd_layout = self.config.qkv_layout.is_thd()
        is_sliding_window = self.config.window_size[0] != -1
        if is_context_parallel and is_thd_layout and is_sliding_window and self.use_scanloop():
            raise ValueError(
                f"{header} with THD format and sliding window does not support using scan loop"
            )

    def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention.

        The step always uses the ``BSHD_BS2HD`` layout (K and V packed together, as produced by
        ``stack_kv``) and the given ``attn_mask_type``. All other fields are copied from
        ``self.config``, except that ``cp_striped_window_size`` is cleared.
        """
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
            attn_mask_type=attn_mask_type,
            softmax_type=self.config.softmax_type,
            qkv_layout=QKVLayout.BSHD_BS2HD,
            scaling_factor=self.config.scaling_factor,
            dropout_probability=self.config.dropout_probability,
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
            window_size=self.config.window_size,
            bottom_right_diagonal=attn_mask_type.is_bottom_right(),
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
            cp_striped_window_size=None,
            stripe_size=self.config.stripe_size,
        )

    def stack_kv(self, k, v):
        """Stacks k and v tensors if not stacked.

        Returns ``k`` unchanged for KV-packed layouts, ``stack([k, v], axis=2)`` for separate
        layouts, and an empty placeholder array otherwise.
        """
        _not_used = jnp.zeros(0, dtype=k.dtype)
        if self.config.qkv_layout.is_kvpacked():
            return k
        if self.config.qkv_layout.is_separate():
            return jnp.stack([k, v], axis=2)
        return _not_used

    def unstack_kv(self, kv):
        """Un-stacks k and v tensors if not stacked.

        Inverse of ``stack_kv``. It returns ``(k, v)``, where the second item is an empty
        placeholder for KV-packed layouts, and both items are placeholders for any other
        layout.
        """
        _not_used = jnp.zeros(0, dtype=kv.dtype)
        if self.config.qkv_layout.is_kvpacked():
            return kv, _not_used
        if self.config.qkv_layout.is_separate():
            return jnp.unstack(kv, axis=2)
        return _not_used, _not_used  # fall through

    def permute_kv(self, kv, cp_perm):
        """Permutes kv around the ring as described by cp_perm."""
        return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)

    @staticmethod
    def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
        """
        Corrects the output and softmax_aux tensor after each iteration of ring attention.

        See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
        derivation of this equation.
        """
        # softmax_aux is indexed (batch, head, seq, 1) in the callers; transpose it so it
        # broadcasts against output's (batch, seq, head, dim) ordering.
        new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
            0, 2, 1, 3
        ) * (output - partial_output)
        new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
        return new_out, new_aux

    def adjust_seqlen(self, seqlen, max_seqlen, idx):
        """Adjust the sequence length per step.

        Returns ``seqlen - max_seqlen * idx`` clamped to the range ``[0, max_seqlen]``.
        """
        seqlen_of_curr_step = seqlen - max_seqlen * idx
        seqlen_of_curr_step = jnp.where(seqlen_of_curr_step < 0, 0, seqlen_of_curr_step)
        seqlen_per_step = jnp.where(
            seqlen_of_curr_step < max_seqlen, seqlen_of_curr_step, max_seqlen
        )
        return seqlen_per_step


class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Ring Attention Forward Primitive

    Context parallel forward pass. Each step passes the KV block to the next rank on the
    ``config.cp_axis`` ring and merges the per-step partial output and softmax statistics into
    a running float32 accumulator.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the per-shard ring attention forward implementation.

        If the CP axis has size 1, this delegates to ``FusedAttnFwdPrimitive.partition``.
        Otherwise it returns ``(mesh, ring_attn_fwd_impl, out_shardings, arg_shardings)``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[5] = seed_sharding
        # Ensure segment_pos gets same sharding as ID.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def ring_attn_fwd_impl(
            q,
            k,
            v,
            bias,
            _softmax_offset,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            batch, q_max_seqlen, head, _ = q.shape
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            # Each step, rank i sends its current KV block to rank (i + 1) % cp_size.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            # Running accumulators, kept in float32 and cast back to q.dtype at the end.
            output = jnp.zeros(q.shape, dtype=jnp.float32)
            softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

            def _fused_attn_step(q_step, kv_step, q_seqlen_step, kv_seqlen_step, attn_mask_type):
                """Run one fused attention call and return its (output, softmax_aux)."""
                output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                    q_step,
                    kv_step,
                    _not_used,
                    bias,
                    _softmax_offset,
                    seed,
                    q_seqlen_step,
                    kv_seqlen_step,
                    q_seq_offsets,
                    k_seq_offsets,
                    _q_segment_ids,
                    _kv_segment_ids,
                    _q_segment_pos,
                    _kv_segment_pos,
                    config=helper.get_step_config(attn_mask_type),
                )
                return output_per_step, softmax_aux_per_step

            def scan_kv_block(idx, carry):
                kv, output, softmax_aux = carry

                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)

                # Every branch below starts from the same per-step sequence lengths.
                q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)

                def causal_mask_compute():
                    return _fused_attn_step(
                        q, kv, q_seqlen_per_step, kv_seqlen_per_step, AttnMaskType.CAUSAL_MASK
                    )

                def no_mask_compute():
                    return _fused_attn_step(
                        q, kv, q_seqlen_per_step, kv_seqlen_per_step, AttnMaskType.NO_MASK
                    )

                def half_kv_no_mask_compute():
                    # Attend to the first half of the KV block only.
                    kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1)
                    return _fused_attn_step(
                        q, kv_part, q_seqlen_per_step, kv_seqlen_per_step // 2, AttnMaskType.NO_MASK
                    )

                def half_q_no_mask_compute():
                    # Only the second half of q attends; the first half gets an empty
                    # contribution (zero output, -inf softmax stats).
                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
                    output_per_step, softmax_aux_per_step = _fused_attn_step(
                        q_part, kv, q_seqlen_per_step // 2, kv_seqlen_per_step, AttnMaskType.NO_MASK
                    )
                    output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
                    softmax_aux_per_step = jnp.concat(
                        [
                            jnp.full_like(softmax_aux_per_step, -jnp.inf),
                            softmax_aux_per_step,
                        ],
                        axis=2,
                    )
                    return output_per_step, softmax_aux_per_step

                def skip_compute():
                    output_per_step = jnp.zeros_like(q)
                    softmax_aux_per_step = jnp.full(
                        (batch, head, q.shape[1], 1), -jnp.inf, dtype=jnp.float32
                    )
                    return output_per_step, softmax_aux_per_step

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # Given cp_perm, at step idx this rank holds the KV block that started on
                    # rank (cp_rank - idx) mod cp_size.
                    # This is for nested jax.lax.cond
                    def jax_cond_wrap():
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    output_per_step, softmax_aux_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    output_per_step, softmax_aux_per_step = no_mask_compute()

                def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    # pylint: disable=unused-argument
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    return helper.correct_output_and_softmax_aux(
                        output, softmax_aux, output_per_step, softmax_aux_per_step
                    )

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    (idx == 0),
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, output, softmax_aux)

            carry = (kv, output, softmax_aux)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            _, output, softmax_aux = carry

            output = output.astype(q.dtype)
            return output, softmax_aux, rng_state

        return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings


# Register the ring (P2P) CP forward primitive using `.base.register_primitive`.
register_primitive(FusedRingAttnFwdPrimitive)


class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Ring Attention Backward Primitive
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition the backward pass as a ring over the context-parallel axis.

        Falls back to ``FusedAttnBwdPrimitive.partition`` when the context-parallel
        axis has size 1. Otherwise each rank keeps its Q shard while the KV shard,
        together with the partially accumulated dK/dV, is sent to rank ``i + 1``
        with ``helper.permute_kv`` on each of ``cp_size`` steps.

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)``; ``impl`` returns
            ``(dq, dk, dv, dbias, dsoftmax_offset)``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        softmax_offset_spec = get_padded_spec(arg_infos[4])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        # Ring attention doesn't use dsoftmax_offset, but we need to return it for arity matching
        dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))

        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Ensure segment_pos gets same sharding as ID.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)

        out_shardings = (
            dq_sharding,
            dk_sharding,
            dv_sharding,
            dbias_sharding,
            dsoftmax_offset_sharding,
        )

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def ring_attn_bwd_impl(
            q,
            k,
            v,
            bias,
            _softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            _not_used = jnp.zeros(0, dtype=output.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            q_max_seqlen = q.shape[1]
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
            dbias = jnp.zeros_like(bias)

            def bwd_step(
                q_blk,
                kv_blk,
                softmax_aux_blk,
                output_blk,
                doutput_blk,
                q_seqlen_blk,
                kv_seqlen_blk,
                attn_mask_type,
            ):
                """Run the fused backward kernel on one (Q, KV) sub-block.

                Returns ``(dq, dk_dv, dbias)`` for the given sub-block; the
                kernel's dsoftmax_offset output is discarded.
                """
                dq_blk, dk_dv_blk, _, dbias_blk, _ = FusedAttnBwdPrimitive.impl(
                    q_blk,
                    kv_blk,
                    _not_used,
                    bias,
                    _softmax_offset,
                    softmax_aux_blk,
                    rng_state,
                    output_blk,
                    doutput_blk,
                    q_seqlen_blk,
                    kv_seqlen_blk,
                    q_seq_offsets,
                    k_seq_offsets,
                    _q_segment_ids,
                    _kv_segment_ids,
                    _q_segment_pos,
                    _kv_segment_pos,
                    config=helper.get_step_config(attn_mask_type),
                )
                return dq_blk, dk_dv_blk, dbias_blk

            def scan_kv_block(idx, carry):
                kv, dq, dk_dv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.
                kv_dk_dv = jnp.stack([kv, dk_dv])
                kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)

                def mask_compute(attn_mask_type):
                    # Full Q against the full KV shard with the given mask type.
                    return bwd_step(
                        q,
                        kv,
                        softmax_aux,
                        output,
                        doutput,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        attn_mask_type,
                    )

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # Full Q against only the first half of the KV shard, unmasked.
                    kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
                    dq_per_step, dk_dv_per_step, dbias_per_step = bwd_step(
                        q,
                        kv_part,
                        softmax_aux,
                        output,
                        doutput,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2,
                        AttnMaskType.NO_MASK,
                    )
                    # Pad dK/dV back to the full shard length; the untouched second half is zero.
                    dk_dv_per_step = jnp.concat(
                        [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def half_q_no_mask_compute():
                    # Only the second half of Q against the full KV shard, unmasked.
                    q_half = q_max_seqlen // 2
                    q_part = lax.slice_in_dim(q, q_half, q_max_seqlen, axis=1)
                    output_part = lax.slice_in_dim(output, q_half, q_max_seqlen, axis=1)
                    doutput_part = lax.slice_in_dim(doutput, q_half, q_max_seqlen, axis=1)
                    softmax_aux_part = lax.slice_in_dim(
                        softmax_aux, q_half, q_max_seqlen, axis=2
                    )
                    dq_per_step, dk_dv_per_step, dbias_per_step = bwd_step(
                        q_part,
                        kv,
                        softmax_aux_part,
                        output_part,
                        doutput_part,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2,
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        AttnMaskType.NO_MASK,
                    )
                    # Pad dQ back to the full shard length; the untouched first half is zero.
                    dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def skip_compute():
                    return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # This is for nested jax.lax.cond
                    def jax_cond_wrap():
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute()

                # The dK/dV accumulator received from the previous rank belongs to the same
                # KV block that this step just computed on (accumulators trail the KV shards
                # by one hop), so this step's gradient is added to it.
                kv_next, dk_dv = jnp.unstack(kv_dk_dv)
                dq = dq + dq_per_step
                dk_dv = dk_dv + dk_dv_per_step

                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step

                return (kv_next, dq, dk_dv, dbias)

            carry = (kv, dq, dk_dv, dbias)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (_, dq, dk_dv, dbias) = carry

            # Final permute to put gradients back to their final resting place.
            dk_dv = helper.permute_kv(dk_dv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dk_dv)

            # Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
            dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
            return dq, dk, dv, global_dbias, dummy_dsoftmax_offset

        return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings


# Register the ring-attention backward primitive (see .base.register_primitive).
register_primitive(FusedRingAttnBwdPrimitive)


def adjust_cp_striped_window_size(q_pos0, kv_pos0, cp_size, window_size):
    """
    Adjust window size with cp_size for striped sharding, where both q_pos and
    kv_pos are arithmetic sequences like [x, x+cp_size, x+2*cp_size, ...].

    Args:
        q_pos0: First query position of the local striped sequence (host int).
        kv_pos0: First key/value position of the KV stripe being processed (host int).
        cp_size: Context-parallel size, i.e. the stride of both position sequences.
        window_size: ``(left, right)`` sliding window measured in token positions.
            A negative entry (conventionally -1) means that side is unlimited and
            is passed through unchanged as -1.

    Returns:
        ``(left, right)`` window measured in strided steps (units of cp_size),
        suitable for the per-step sub-block config.

    Example 1:
        q_pos = kv_pos = [0, 8, 16, 24, 32], cp_size = 8, window_size = (15, 0).
        q_pos = 32 can look at kv_pos at [24, 32]. The effective mask is:
              0  8 16 24 32
           ----------------
         0 |  1  0  0  0  0
         8 |  1  1  0  0  0
        16 |  0  1  1  0  0
        24 |  0  0  1  1  0
        32 |  0  0  0  1  1
        SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
        Adjusted window size = (1, 0).
    Example 2:
        q_pos = [0, 8, 16, 24, 32], kv_pos = [1, 9, 17, 25, 33], cp_size = 8,
        window_size = (15, 0). The effective mask is:
              1  9 17 25 33
           ----------------
         0 |  0  0  0  0  0
         8 |  1  0  0  0  0
        16 |  1  1  0  0  0
        24 |  0  1  1  0  0
        32 |  0  0  1  1  0
        SequenceDescriptor outputs:
        q_seqlen = [4, ...], q_seq_offsets = [1, ...],
        kv_seqlen = [4, ...], kv_seq_offsets = [0, ...].
        If diagonal are all 1, left window size = 2. Now since diagonal are all 0,
        we need to use left window size = 2 - 1 = 1 to make cuDNN work.
    Example 3:
        q_pos = [7, 15, 23, 31, 39], kv_pos = [0, 8, 16, 24, 32], cp_size = 8,
        window_size = (22, 0). The effective mask is:
              0  8 16 24 32
           ----------------
         7 |  1  0  0  0  0
        15 |  1  1  0  0  0
        23 |  0  1  1  0  0
        31 |  0  0  1  1  0
        39 |  0  0  0  1  1
        SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
        Adjusted window size = (1, 0).
    """
    left_window, right_window = window_size

    if left_window < 0:
        # Unlimited on the left stays unlimited regardless of the stripe offset.
        left_steps = -1
    else:
        left_limit = q_pos0 - left_window
        # Count how many left steps of size cp_size we can take from kv_pos0 - cp_size
        left_steps = max((kv_pos0 - cp_size - left_limit) // cp_size + 1, 0)
        # If kv_pos0 > q_pos0, the diagonal is fully masked, so we must reduce the
        # left window size by 1 (see Example 2).
        # NOTE(review): when the finite window is smaller than the stripe gap, this
        # yields -1, which is also the "unlimited" sentinel. Confirm this case cannot
        # occur for supported window sizes.
        if kv_pos0 > q_pos0:
            left_steps -= 1

    if right_window < 0:
        # Unlimited on the right stays unlimited.
        right_steps = -1
    else:
        right_limit = q_pos0 + right_window
        # Count how many right steps of size cp_size we can take from kv_pos0 + cp_size
        right_steps = max((right_limit - kv_pos0 - cp_size) // cp_size + 1, 0)

    return left_steps, right_steps


class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Striped Ring Attention Forward Primitive
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition the forward pass as a striped ring over the context-parallel axis.

        Falls back to ``FusedAttnFwdPrimitive.partition`` when the context-parallel
        axis has size 1. Otherwise the KV shard and its segment ids/positions are
        sent to rank ``i + 1`` on each of ``cp_size`` steps. Per-step outputs are
        merged in float32 using their softmax_aux statistics (a log-sum-exp merge).

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)``; ``impl`` returns
            ``(output, softmax_aux, rng_state)``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[5] = seed_sharding
        # Ensure segment_pos gets same sharding as ID.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def fwd_impl(
            q,
            k,
            v,
            bias,
            _softmax_offset,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            if not config.qkv_layout.is_qkvpacked():
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # We need cp_rank to be a host value for adjust_cp_striped_window_size()
            cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            batch, q_max_seqlen, head, _ = q.shape
            output = jnp.zeros(q.shape, dtype=jnp.float32)
            softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

            has_window = config.window_size != (-1, -1)

            def step_config(idx):
                """Sub-block config for step ``idx``; idx must be a host int when windowed."""
                if not has_window:
                    return subblock_config
                kv_src_rank = (cp_size + cp_rank - idx) % cp_size
                # Note: all inputs of adjust_cp_striped_window_size should be host values
                cp_striped_window_size = adjust_cp_striped_window_size(
                    cp_rank, kv_src_rank, cp_size, config.window_size
                )
                return replace(subblock_config, cp_striped_window_size=cp_striped_window_size)

            def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
                # No correction done here but we cast outputs to float32 and perform reduction
                # in full precision.
                return output_per_step.astype(jnp.float32), softmax_aux_per_step

            def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                # Numerically stable merge: new softmax_aux = logaddexp(softmax_aux,
                # softmax_aux_per_step), and the outputs are blended with the matching
                # weights sigmoid(softmax_aux_per_step - softmax_aux).
                new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * (
                    output - output_per_step
                )
                new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step)
                return new_out, new_aux

            def scan_kv_block(idx, carry):
                kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry

                # TODO(rewang): To check whether we need special handle for the last idx
                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                    q,
                    kv,
                    _not_used,
                    bias,
                    _softmax_offset,
                    seed,
                    q_seqlen,
                    kv_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    q_segment_ids,
                    kv_segment_ids,
                    q_segment_pos,
                    kv_segment_pos,
                    config=step_config(idx),
                )
                softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    idx == 0,
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux)

            carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux)
            # adjust_cp_striped_window_size() needs a host-side step index, whereas
            # lax.fori_loop passes a traced index; unroll whenever a window is set.
            if helper.use_scanloop() and not has_window:
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (_, _, _, output, softmax_aux) = carry

            return output.astype(q.dtype), softmax_aux, rng_state

        return mesh, fwd_impl, out_shardings, arg_shardings


# Register the striped ring-attention forward primitive (see .base.register_primitive).
register_primitive(FusedRingAttnStripedFwdPrimitive)


class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Striped Ring Attention Backward Primitive
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition the backward pass as a striped ring over the context-parallel axis.

        Falls back to ``FusedAttnBwdPrimitive.partition`` when the context-parallel
        axis has size 1. Otherwise the KV shard, its partially accumulated dK/dV and
        its segment ids/positions are sent to rank ``i + 1`` on each of ``cp_size``
        steps.

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)``; ``impl`` returns
            ``(dq, dk, dv, dbias, dsoftmax_offset)``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Ensure segment_pos gets same sharding as ID.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)

        # dq, dk, dv, dbias, dsoftmax_offset sharding = q, k, v, bias, softmax_offset sharding
        out_shardings = tuple(arg.sharding for arg in arg_infos[:5])

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def bwd_impl(
            q,
            k,
            v,
            bias,
            _softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            _not_used = jnp.zeros(0, dtype=output.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            if not config.qkv_layout.is_qkvpacked():
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # We need cp_rank to be a host value for adjust_cp_striped_window_size()
            cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dkv = jnp.zeros_like(kv)
            dbias = jnp.zeros_like(bias)

            has_window = config.window_size != (-1, -1)

            def step_config(idx):
                """Sub-block config for step ``idx``; idx must be a host int when windowed."""
                if not has_window:
                    return subblock_config
                kv_src_rank = (cp_size + cp_rank - idx) % cp_size
                # Note: all inputs of adjust_cp_striped_window_size should be host values
                cp_striped_window_size = adjust_cp_striped_window_size(
                    cp_rank, kv_src_rank, cp_size, config.window_size
                )
                return replace(subblock_config, cp_striped_window_size=cp_striped_window_size)

            def scan_kv_block(idx, carry):
                kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.
                kv_dkv = jnp.stack([kv, dkv])
                kv_dkv = helper.permute_kv(kv_dkv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                dq_per_step, dkv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
                    q,
                    kv,
                    _not_used,
                    bias,
                    _softmax_offset,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_seqlen,
                    kv_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    q_segment_ids,
                    kv_segment_ids,
                    q_segment_pos,
                    kv_segment_pos,
                    config=step_config(idx),
                )

                # The received dK/dV accumulator belongs to the KV block used in this
                # step (accumulators trail the KV shards by one hop).
                kv_next, dkv = jnp.unstack(kv_dkv)
                dq = dq + dq_per_step
                dkv = dkv + dkv_per_step
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step

                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, dq, dkv, dbias)

            carry = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias)
            # adjust_cp_striped_window_size() needs a host-side step index, whereas
            # lax.fori_loop passes a traced index; unroll whenever a window is set.
            if helper.use_scanloop() and not has_window:
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for idx in range(cp_size):
                    carry = scan_kv_block(idx, carry)
            (_, _, _, dq, dkv, dbias) = carry

            # Final permute to put gradients back to their final resting place.
            dkv = helper.permute_kv(dkv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dkv)

            # Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
            dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
            return dq, dk, dv, global_dbias, dummy_dsoftmax_offset

        return mesh, bwd_impl, out_shardings, arg_shardings


# Register the striped ring-attention backward primitive (see .base.register_primitive).
register_primitive(FusedRingAttnStripedBwdPrimitive)


3307
def _maybe_context_parallel_axis(cp_axis: str):
3308
    if not cp_axis and is_mesh_available():
3309
3310
3311
3312
3313
3314
3315
3316
        gmr = global_mesh_resource()
        if gmr is not None:
            cp_axis = gmr.cp_resource
        else:
            cp_axis = ""
    return cp_axis


3317
3318
3319
def fused_attn_fwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
3320
    softmax_offset: Optional[jnp.ndarray],
3321
    sequence_descriptor: SequenceDescriptor,
3322
    seed: Optional[jnp.ndarray],
Reese Wang's avatar
Reese Wang committed
3323
3324
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
3325
    softmax_type: AttnSoftmaxType,
Reese Wang's avatar
Reese Wang committed
3326
    qkv_layout: QKVLayout,
3327
3328
3329
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
3330
    max_segments_per_seq: int,
3331
    window_size: Optional[Tuple[int, int]] = None,
3332
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
3333
3334
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
3335
    stripe_size: int | None = None,
3336
) -> jnp.ndarray:
3337
    """
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
3350
    Perform the forward pass of with cuDNN fused attention implementations.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
3351
        softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
3352
3353
3354
3355
3356
3357
3358
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
Reese Wang's avatar
Reese Wang committed
3359
3360
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
3361
        softmax_type (AttnSoftmaxType): Type of softmax.
Reese Wang's avatar
Reese Wang committed
3362
        qkv_layout (QKVLayout): Layout of the QKV tensors.
3363
3364
3365
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
3366
3367
3368
3369
3370
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
3371
3372
3373
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
3374
        stripe_size (int | None): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing
3375
3376
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
3377
    """
3378
3379
3380
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
3381

Reese Wang's avatar
Reese Wang committed
3382
3383
3384
3385
3386
3387
3388
3389
3390
3391
3392
3393
3394
3395
3396
3397
3398
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
3399
        assert bias is None
3400
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
3401

3402
3403
3404
3405
3406
3407
3408
3409
3410
3411
3412
3413
3414
3415
3416
3417
3418
3419
3420
3421
3422
3423
3424
3425
3426
    if softmax_offset is None:
        assert (
            softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX
        ), f"Softmax type {softmax_type} is not supported when softmax_offset is None"
        if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            num_heads = qkv[0].shape[-2]
            # Create tensor [1, h, 1, 1] filled with zeros (logit value = 0)
            # This adds exp(0 - x_max) = exp(-x_max) to the denominator,
            # which contributes exactly 1 after normalization, giving: exp(x_i) / (sum(exp(x_j)) + 1)
            softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
            # Shard by heads dimension
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )
        else:
            assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX
            softmax_offset = jnp.zeros(0, dtype=jnp.float32)
    else:
        assert softmax_offset.dtype == jnp.float32
        # Shard by heads dimension if not VANILLA_SOFTMAX
        if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )

3427
    fused_config = _FusedAttnConfig(
3428
3429
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
3430
        qkv_layout=qkv_layout,
3431
        softmax_type=softmax_type,
3432
3433
3434
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
3435
        max_segments_per_seq=max_segments_per_seq,
3436
        window_size=(-1, -1) if window_size is None else window_size,
3437
        bottom_right_diagonal=attn_mask_type.is_bottom_right(),
3438
3439
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
3440
        cp_striped_window_size=None,
3441
        stripe_size=stripe_size,
3442
3443
    )

3444
    primitive = None
3445
3446
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
3447
3448
3449
3450
            if qkv_layout.is_thd():
                primitive = FusedAttnCPStripedWithAllGatherFwdPrimitive.outer_primitive
            else:
                primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
3451
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
3452
3453
3454
3455
3456
            # We must use stripe attention for THD-RING
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnFwdPrimitive.outer_primitive
3457

3458
    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
3459
    output, softmax_aux, rng_state = primitive.bind(
3460
3461
        *qkv_for_primitive,
        bias,
3462
        softmax_offset,
3463
        seed,
3464
        *seq_desc_flatten,
3465
        config=fused_config,
3466
    )
3467
3468
    rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
    return (output, softmax_aux, rng_state)
3469
3470


3471
3472
3473
def fused_attn_bwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
3474
    softmax_offset: Optional[jnp.ndarray],
3475
3476
3477
3478
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
3479
    sequence_descriptor: SequenceDescriptor,
Reese Wang's avatar
Reese Wang committed
3480
3481
3482
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
3483
    softmax_type: AttnSoftmaxType,
3484
3485
3486
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
3487
    max_segments_per_seq: int,
3488
    window_size: Optional[Tuple[int, int]] = None,
3489
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
3490
3491
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
3492
    stripe_size: int | None = None,
3493
):
3494
    """
3495
3496
3497
3498
3499
3500
3501
3502
3503
3504
3505
    Perform the backward pass of the cuDNN fused attention implementations.

    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
        used in the forward pass. It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
3506
        softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
3507
3508
3509
3510
3511
3512
3513
3514
3515
3516
        softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
        rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
        output (jnp.ndarray): The output tensor from the forward pass.
        doutput (jnp.ndarray): The gradient with respect to the output.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
Reese Wang's avatar
Reese Wang committed
3517
3518
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
3519
        softmax_type (AttnSoftmaxType): Type of softmax.
Reese Wang's avatar
Reese Wang committed
3520
        qkv_layout (QKVLayout): Layout of the QKV tensors.
3521
3522
3523
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
3524
3525
3526
3527
3528
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size .
3529
3530
3531
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
3532
        stripe_size (int | None): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing
3533
3534
3535
3536
3537
    Returns:
        Tuple[jnp.ndarray, ...], jnp.ndarray:
        - The first tuple contains the gradients with respect to the input `qkv` tensors in the
          same format as the input `qkv`.
        - The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
3538
    """
3539
3540
3541
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)

Reese Wang's avatar
Reese Wang committed
3542
3543
3544
3545
3546
3547
3548
3549
3550
3551
3552
3553
3554
3555
3556
3557
3558
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
3559
        assert bias is None
3560
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
3561

3562
3563
3564
3565
3566
3567
3568
3569
3570
3571
3572
3573
3574
3575
3576
3577
3578
3579
3580
3581
3582
3583
    if softmax_offset is None:
        assert softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX, f"Unknown {softmax_type=}"
        if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            num_heads = qkv[0].shape[-2]
            # Create tensor [1, h, 1, 1] filled with zeros
            softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
            # Shard by heads dimension
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )
        elif softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = jnp.zeros(0, dtype=jnp.float32)
        else:
            raise NotImplementedError(f"Unknown {softmax_type=}")
    else:
        softmax_offset = softmax_offset.astype(jnp.float32)
        # Shard by heads dimension if not VANILLA_SOFTMAX
        if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )

3584
    compute_capabilities = get_all_device_compute_capability()
3585
3586
3587
3588
3589
3590
3591
3592
3593
3594
3595
3596
3597
3598
    if any(x >= 100 for x in compute_capabilities) and is_training:
        assert (
            FusedAttnHelper.is_non_deterministic_allowed()
            and get_cudnn_version() >= (9, 7, 0)
            and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0)
        ) or (
            not FusedAttnHelper.is_non_deterministic_allowed()
            and get_cudnn_version() >= (9, 18, 1)
            and attn_bias_type == AttnBiasType.NO_BIAS
            and dropout_probability == 0.0
        ), (
            "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout,"
            " and deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout"
        )
3599

3600
3601
3602
3603
    fused_config = _FusedAttnConfig(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
3604
        softmax_type=softmax_type,
3605
3606
3607
3608
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
3609
        window_size=(-1, -1) if window_size is None else window_size,
3610
        bottom_right_diagonal=attn_mask_type.is_bottom_right(),
3611
3612
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
3613
        cp_striped_window_size=None,
3614
        stripe_size=stripe_size,
3615
3616
    )

3617
    primitive = None
3618
3619
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
3620
3621
3622
3623
            if qkv_layout.is_thd():
                primitive = FusedAttnCPStripedWithAllGatherBwdPrimitive.outer_primitive
            else:
                primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
3624
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
3625
3626
3627
3628
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnBwdPrimitive.outer_primitive
3629
3630

    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
3631
    *qkv_grads, bias_grad, softmax_offset_grad = primitive.bind(
3632
        *qkv_for_primitive,
3633
        bias,
3634
        softmax_offset,
3635
3636
3637
3638
        softmax_aux,
        rng_state,
        output,
        doutput,
3639
        *seq_desc_flatten,
3640
        config=fused_config,
3641
    )
3642
    return tuple(qkv_grads[: len(qkv)]), bias_grad, softmax_offset_grad