# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
import operator
import os
import warnings
from dataclasses import dataclass, replace
from functools import partial, reduce
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from jax import dtypes, lax, ffi
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.custom_partitioning import SdyShardingRule

import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    QKVFormat,
    CPStrategy,
    SequenceDescriptor,
)
from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES

from .base import BasePrimitive, register_primitive
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    get_cudnn_version,
    get_all_device_compute_capability,
)
from ..sharding import (
    global_mesh_resource,
    lax_paral_op,
    all_reduce_sum_along_dp_fsdp,
    get_mesh_axis_size,
    get_mesh_axis_rank,
    get_mesh_axis_rank_host,
    get_all_mesh_axes,
    num_of_devices,
    with_sharding_constraint,
)


# Public names exported by `from <this module> import *`.
__all__ = [
    "FusedAttnHelper",
    "fused_attn_fwd",
    "fused_attn_bwd",
]


@partial(
    jax.tree_util.register_dataclass,
    data_fields=[],
    meta_fields=[
        "attn_bias_type",
        "attn_mask_type",
        "softmax_type",
        "qkv_layout",
        "scaling_factor",
        "dropout_probability",
        "is_training",
        "max_segments_per_seq",
        "window_size",
        "context_parallel_load_balanced",
        "cp_axis",
        "cp_striped_window_size",
    ],
)
@dataclass(frozen=True)
class _FusedAttnConfig:
    """
    Passes static configuration properties of fused attention.

    Every field is registered as a pytree *meta* field (there are no data fields), so an
    instance carries no arrays and is treated as static auxiliary data by JAX. Instances
    are handed to the fused-attention primitives through their ``config`` keyword argument.
    """

    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    softmax_type: AttnSoftmaxType
    qkv_layout: QKVLayout
    scaling_factor: float
    dropout_probability: float
    is_training: bool
    max_segments_per_seq: int
    # (left, right) window sizes; forwarded as window_size_left / window_size_right.
    window_size: Tuple[int, int]
    context_parallel_load_balanced: bool
    cp_axis: str
    # Only for CP + Ring + THD + SWA. May be None; when set, the forward lowering uses it
    # in place of `window_size`.
    cp_striped_window_size: Optional[Tuple[int, int]]


@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend.

    Bundles the problem description needed to ask ``transformer_engine_jax`` which
    fused-attention backend applies:
    - training mode and the q/kv dtypes;
    - qkv layout and the bias / mask / softmax types;
    - dropout rate;
    - head counts, max sequence lengths and head dims;
    - attention window.

    Also provides static helpers for environment flags and for parsing q/k/v shapes.
    """

    is_training: bool
    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: QKVLayout
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    softmax_type: AttnSoftmaxType
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
    q_max_seqlen: int
    kv_max_seqlen: int
    head_dim_qk: int
    head_dim_v: int
    # (left, right) attention window
    window_size: Tuple[int, int]

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend.

        Returns:
            The ``NVTE_Fused_Attn_Backend`` value reported by
            ``transformer_engine_jax.get_fused_attn_backend`` for this configuration.
            ``is_fused_attn_kernel_available`` treats ``NVTE_No_Backend`` as "none".
        """
        return transformer_engine_jax.get_fused_attn_backend(
            self.is_training,
            jax_dtype_to_te_dtype(self.q_dtype),
            jax_dtype_to_te_dtype(self.kv_dtype),
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.softmax_type.value,
            self.dropout_probability,
            self.q_num_heads,
            self.kv_num_heads,
            self.q_max_seqlen,
            self.kv_max_seqlen,
            self.head_dim_qk,
            self.head_dim_v,
            self.window_size[0],
            self.window_size[1],
        )

    @staticmethod
    def is_non_deterministic_allowed():
        """Check if non-deterministic kernels are allowed.

        Reads ``NVTE_ALLOW_NONDETERMINISTIC_ALGO`` (default ``"1"``) as an integer; any
        non-zero value allows non-deterministic algorithms.
        """
        return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

    @staticmethod
    def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
        """Parse qkv aval.

        Expected shapes per layout:
        - qkv-packed: q is ``(*batch, seqlen, 3, heads, head_dim)``.
        - kv-packed: q is ``(*batch, q_seqlen, heads, head_dim)`` and k is
          ``(*batch, kv_seqlen, 2, gqa_groups, head_dim)``.
        - separate: q/k/v are each ``(*batch, seqlen, heads_or_groups, head_dim)``.

        Returns:
            ``(q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
            q_head_dim, v_head_dim)``; ``q_batch_shape`` is a list of the leading dims.

        Raises:
            NotImplementedError: for the SBHD qkv format.
            ValueError: for a layout that is neither qkv-packed, kv-packed nor separate.
            AssertionError: on mismatched shapes or dtypes between q, k and v.
        """
        if qkv_layout.get_qkv_format() == QKVFormat.SBHD:
            raise NotImplementedError(
                f"Unsupported {qkv_layout=}: the SBHD format is not implemented"
            )
        if qkv_layout.is_qkvpacked():
            *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
            # Packed q/k/v share sequence length, head count and head dim.
            kv_max_seqlen = q_max_seqlen
            num_gqa_groups = attn_heads
            v_head_dim = q_head_dim
            assert nqkv == 3
        elif qkv_layout.is_kvpacked():
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape
            assert q_batch_shape == kv_batch_shape
            assert q_head_dim == v_head_dim
            assert nkv == 2
        elif qkv_layout.is_separate():
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape
            *v_batch_shape, v_max_seqlen, v_num_gqa_groups, v_head_dim = v_aval.shape
            assert (
                q_head_dim == k_head_dim
            ), f"Mismatched q_head_dim: {q_head_dim} and k_head_dim: {k_head_dim}"
            assert (
                k_max_seqlen == v_max_seqlen
            ), f"Mismatched k_max_seqlen: {k_max_seqlen} and v_max_seqlen: {v_max_seqlen}"
            kv_max_seqlen = k_max_seqlen
            assert q_batch_shape == k_batch_shape == v_batch_shape, (
                f"Mismatched qkv batch size for q_batch_shape: {q_batch_shape}, k_batch_shape:"
                f" {k_batch_shape} and v_batch_shape: {v_batch_shape}"
            )
            assert k_num_gqa_groups == v_num_gqa_groups, (
                f"Mismatched k_num_gqa_groups: {k_num_gqa_groups} and v_num_gqa_groups:"
                f" {v_num_gqa_groups}"
            )
            num_gqa_groups = k_num_gqa_groups
        else:
            raise ValueError(f"Unexpected {qkv_layout=}")
        assert q_aval.dtype == k_aval.dtype == v_aval.dtype, (
            f"Mismatched data types for q_aval: {q_aval.dtype}, k_aval: {k_aval.dtype}, v_aval:"
            f" {v_aval.dtype}"
        )
        return (
            q_batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        )


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.

    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    However, JAX doesn't enable 64 bits by default, so each 64-bit value is
    emulated with two 32-bit (``rng_state_dtype``) elements.
    The offset calculation is maintained in the backend.
    """

    # Element dtype of the seed and of the rng_state buffer.
    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64 -> two uint32 elements
    seed_size: int = 2
    # (seed, offset) with internal dtype int64 -> four uint32 elements
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.

        Args:
            seed: seed array, or None when dropout is not active.
            dropout_probability: dropout rate; together with ``is_training`` decides
                whether a real seed is required.
            is_training: whether the attention call runs in training mode.

        Returns:
            ``seed`` as an array of ``rng_state_dtype``. When ``seed`` is None, this is a
            zero-filled array of ``2 * num_of_devices()`` elements.

        Raises:
            AssertionError: if ``seed`` is None while dropout is enabled, or if the seed has
                fewer than ``seed_size`` elements.
        """
        # Jax can't bind None, create a dummy tensor for None
        if seed is None:
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            # NOTE(review): presumably one 2-element dummy seed per device; confirm.
            seed = jnp.zeros(2, dtype=self.rng_state_dtype)
            seed = jnp.repeat(seed, num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                "Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning."
            )
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size

        return seed


def generate_cu_seqlen(actual_seqlen):
    """
    Generating cumsum seqlen for a batch.

    Negative entries (the padding marker used for sequence lengths elsewhere in this
    file) are clamped to 0 first, so they contribute nothing to the running total.

    Args:
        actual_seqlen: 1-D array of per-sequence lengths (callers pass a flattened array).

    Returns:
        Array of ``actual_seqlen.size + 1`` elements starting with 0, whose i-th entry is
        the total length of the first i sequences.
    """
    actual_seqlen = jnp.where(actual_seqlen < 0, 0, actual_seqlen)
    cu_seqlen = jnp.cumulative_sum(actual_seqlen, include_initial=True)
    return cu_seqlen


class FusedAttnFwdPrimitive(BasePrimitive):
    """
    Fused Attention Forward Primitive

    Inputs, in order:
    - q, k, v, bias, softmax_offset, seed;
    - q/kv seqlens (or cu_seqlens) and q/k sequence offsets;
    - q/kv segment ids and q/kv segment positions.

    ``config`` (a ``_FusedAttnConfig``) is the static argument at index 14 (see
    ``impl_static_args``). The outer primitive returns (output, softmax_aux, rng_state).
    The inner primitive additionally returns a workspace buffer.
    """

    name = "te_fused_attn_forward_ffi"
    multiple_results = True
    impl_static_args = (14,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_offset_aval,
        seed_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd abstract

        Returns avals for (output, softmax_aux, rng_state, workspace).
        - The softmax_aux shape/dtype depends on the backend selected for this configuration.
        - The workspace size is queried from ``transformer_engine_jax``.
        """
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert (
            q_dtype == k_dtype == v_dtype == bias_dtype
        ), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}"
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
            f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
            f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
        )

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim)
        out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)

        # backend determines the softmax buffer shape/dtype
        backend = FusedAttnHelper(
            config.is_training,
            q_dtype,
            k_dtype,
            config.qkv_layout,
            config.attn_bias_type,
            config.attn_mask_type,
            config.softmax_type,
            config.dropout_probability,
            attn_heads,
            num_gqa_groups,
            q_max_seqlen,
            kv_max_seqlen,
            q_head_dim,
            v_head_dim,
            config.window_size,
        ).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
            softmax_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            # cuDNN 9.6 reduces the required softmax shape
            if get_cudnn_version() >= (9, 6, 0):
                if config.qkv_layout.is_thd():
                    softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
                else:
                    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
            else:
                softmax_shape = (
                    *batch_shape,
                    attn_heads,
                    q_max_seqlen,
                    config.max_segments_per_seq,
                )
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f"Unsupported {backend=}")
        softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)

        # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
        # 32-bit unsigned int to get the buffer size we need in the C++ kernel
        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
        # prepare for the active fused-attn backend
        input_batch = reduce(operator.mul, batch_shape)
        wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            q_head_dim,
            v_head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.softmax_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
        )
        wkspace_aval = q_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        # softmax_offset is a per-head float32 tensor for non-vanilla softmax, otherwise an
        # empty placeholder.
        assert softmax_offset_aval.dtype == jnp.float32
        if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            assert softmax_offset_aval.shape == (1, attn_heads, 1, 1)
        else:
            assert softmax_offset_aval.shape == (0,)

        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract

        Same as ``abstract`` but without the internal workspace aval.
        """
        out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract(
            *args, **kwargs
        )
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_offset,
        seed,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd lowering rules

        Lowers to the ``te_fused_attn_forward_ffi`` FFI target. Problem sizes are derived
        from the input avals and passed as attributes. ``cp_striped_window_size``, when
        set, overrides ``window_size``.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        input_batch = reduce(operator.mul, batch_shape)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        if config.cp_striped_window_size is not None:
            window_size_left = config.cp_striped_window_size[0]
            window_size_right = config.cp_striped_window_size[1]
        else:
            window_size_left = config.window_size[0]
            window_size_right = config.window_size[1]

        return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
            ctx,
            q,
            k,
            v,
            bias,
            softmax_offset,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,  # ffi_lowering needs number of parameters meets primitive.lowering
            input_batch=input_batch,
            bias_batch=bias_batch,
            q_max_seqlen=q_max_seqlen,
            kv_max_seqlen=kv_max_seqlen,
            attn_heads=attn_heads,
            num_gqa_groups=num_gqa_groups,
            bias_heads=bias_heads,
            qk_head_dim=q_head_dim,
            v_head_dim=v_head_dim,
            max_segments_per_seq=config.max_segments_per_seq,
            scaling_factor=float(config.scaling_factor),
            dropout_probability=float(config.dropout_probability),
            bias_type=int(config.attn_bias_type.value),
            mask_type=int(config.attn_mask_type.value),
            qkv_layout=int(config.qkv_layout.value),
            is_training=config.is_training,
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            softmax_type=int(config.softmax_type.value),
        )

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        softmax_offset,
        seed,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd implementation.

        Steps:
        1. Resolves seqlens/offsets through ``SequenceDescriptor``.
        2. For THD layouts, compacts the valid entries and flattens the offsets into the
           (batch * max_seqlen) token space.
        3. Builds cumulative seqlens and binds the inner primitive.

        Returns:
            ``(output, softmax_aux, rng_state)``.
        """
        assert FusedAttnFwdPrimitive.inner_primitive is not None

        sequence_descriptor = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )

        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
                config.qkv_layout,
                config.window_size,
                config.max_segments_per_seq,
            )
        )

        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
                # Move the entries where `condition` holds to the front (in order) and pad
                # the tail with `fill_value`, keeping the original shape.
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
                # Shift each row's non-negative offsets by row_index * max_seqlen.
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
            kv_batch = q_batch = batch[0]

            # Gather valid q_seqlen, which is greater than 0
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1
            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid q_seq_offsets, which is greater and equal to 0
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # And set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _fix_len_take(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _fix_len_take(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_offset,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """
        Batching rule.

        Only the batch dims of q (index 0) and seed (index 5) are consulted. Output and
        softmax_aux are batched along q's dim; rng_state along seed's dim.
        """
        check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
        return (
            FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """
        Derive output shardings from the sharding of q.

        - output follows q without the packed "3" axis.
        - softmax_aux follows q's batch/seqlen/head axes in the order of the softmax layout.
        - rng_state is sharded over all mesh axes on its first dimension.
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])

        # when supported softmax_aux shape is (b, s, h, 1) for thd on cudnn 9.6+
        # otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments)
        is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()

        if config.qkv_layout.is_qkvpacked():
            # q_spec = (...batch, q_seqlen, 3, head, hidden)
            batch_spec = q_spec[:-4]
            seqlen_spec, head_spec = q_spec[-4], q_spec[-2]
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
        elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
            # q_spec = (...batch, q_seqlen, head, hidden)
            # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) for kvpacked
            # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) for separate
            batch_spec = q_spec[:-3]
            seqlen_spec, head_spec = q_spec[-3], q_spec[-2]
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")

        if is_packed_softmax:
            softmax_aux_spec = PartitionSpec(*batch_spec, seqlen_spec, head_spec, None)
        else:
            softmax_aux_spec = PartitionSpec(*batch_spec, head_spec, seqlen_spec, None)
        softmax_aux_sharding = NamedSharding(mesh, softmax_aux_spec)

        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """
        Partitioning rule.

        - Output shardings come from ``result_infos``; seed and rng_state use the
          all-mesh-axes sharding on dim 0.
        - Each segment_pos input reuses the sharding of the matching segment_ids input.
        """
        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[5] = seed_sharding
        # _kv_segment_pos <- _kv_segment_ids, _q_segment_pos <- _q_segment_ids
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(FusedAttnFwdPrimitive.impl, config=config)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(config, mesh, value_types, result_types):
        """Shardy sharding rule mirroring ``infer_sharding_from_operands``."""
        del mesh, result_types

        # Keep in sync with `infer_sharding_from_operands`.
        # We only need the first input. Fill up the rest with placeholders.
        input_spec = [(f"…{x}",) for x in range(len(value_types))]
        # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
        # instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
        rng_sharding = (f"…{len(value_types)}",)

        if config.qkv_layout.is_qkvpacked():
            input_spec[0] = ("…0", "seqlen", "three", "head", "hidden")
        elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
            input_spec[0] = ("…0", "seqlen", "head", "hidden")
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")

        is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
        out_sharding = ("…0", "seqlen", "head", "hidden")
        if is_packed_softmax:
            softmax_aux_sharding = ("…0", "seqlen", "head", "i")
        else:
            softmax_aux_sharding = ("…0", "head", "seqlen", "i")

        return SdyShardingRule(
            tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding)
        )

# Register the forward primitive through the project helper imported from `.base`.
# NOTE(review): presumably this populates inner_primitive / outer_primitive, which
# `impl` and `batcher` assert are not None — confirm in base.register_primitive.
register_primitive(FusedAttnFwdPrimitive)


class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive

    Produces (dq, dk, dv, dbias, dsoftmax_offset) through the
    `te_fused_attn_backward_ffi` custom call. The inner primitive additionally
    returns a workspace buffer, which `outer_abstract` / `impl` drop.
    """

    name = "te_fused_attn_backward_ffi"
    multiple_results = True
    # `config` is positional argument 17 of `impl`, right after the 17 array operands (0-16).
    impl_static_args = (17,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def _get_bias_batch_and_heads(config, bias_shape):
        """Return ``(bias_batch, bias_heads)`` as passed to the TE extension.

        Both are 0 for ``AttnBiasType.NO_BIAS``. Otherwise ``bias_heads`` is the
        third-from-last dimension of ``bias_shape`` and ``bias_batch`` is the product
        of all dimensions before it; the last two dimensions are not used here.
        """
        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            return 0, 0
        *bias_batch_shape, bias_heads, _, _ = bias_shape
        return reduce(operator.mul, bias_batch_shape), bias_heads

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_offset_aval,
        softmax_aux_aval,
        rng_state_aval,
        output_aval,
        doutput_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd abstract

        Returns abstract values for (dq, dk, dv, dbias, dsoftmax_offset, workspace).
        dq/dk/dv/dbias mirror the shape and dtype of q/k/v/bias. The workspace shape
        and dtype are queried from the TE extension.
        """
        del softmax_aux_aval, rng_state_aval, output_aval

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            qk_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        bias_batch, bias_heads = FusedAttnBwdPrimitive._get_bias_batch_and_heads(
            config, bias_aval.shape
        )

        deterministic = not FusedAttnHelper.is_non_deterministic_allowed()

        input_batch = reduce(operator.mul, batch_shape)
        wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            qk_head_dim,
            v_head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.softmax_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            deterministic,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
        )

        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(
            shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
        )

        # Validate incoming softmax_offset shape and dtype
        assert (
            softmax_offset_aval.dtype == jnp.float32
        ), f"Incorrect softmax_offset dtype: {softmax_offset_aval.dtype}, expected: {jnp.float32}"
        if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), (
                f"Incorrect softmax_offset shape for {config.softmax_type}:"
                f" {softmax_offset_aval.shape}, expected: (1, {attn_heads}, 1, 1)"
            )
        else:
            assert softmax_offset_aval.shape == (0,), (
                f"Incorrect softmax_offset shape for {config.softmax_type}:"
                f" {softmax_offset_aval.shape}, expected: (0,)"
            )

        # The gradient has the same (possibly empty) shape as the incoming offset.
        if config.softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
            dsoftmax_offset_aval = q_aval.update(
                shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype
            )
        else:
            dsoftmax_offset_aval = q_aval.update(shape=(1, attn_heads, 1, 1), dtype=jnp.float32)

        return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract

        Same as `abstract`, minus the internal workspace buffer.
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, _ = (
            FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        )
        return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_offset,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        q_segment_ids,
        kv_segment_ids,
        q_segment_pos,
        kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd lowering rules

        Emits the `te_fused_attn_backward_ffi` custom call. Every operand is
        forwarded, and the problem sizes and config are passed as FFI attributes.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            qk_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        input_batch = reduce(operator.mul, batch_shape)

        bias_batch, bias_heads = FusedAttnBwdPrimitive._get_bias_batch_and_heads(
            config, bias_aval.shape
        )

        # NOTE(review): `abstract` sizes the workspace with `config.window_size`, but the
        # kernel may receive `cp_striped_window_size` here; confirm the workspace size
        # does not depend on the window.
        if config.cp_striped_window_size is not None:
            window_size_left = config.cp_striped_window_size[0]
            window_size_right = config.cp_striped_window_size[1]
        else:
            window_size_left = config.window_size[0]
            window_size_right = config.window_size[1]

        return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
            ctx,
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,  # ffi_lowering needs number of parameters meets primitive.lowering
            input_batch=input_batch,
            bias_batch=bias_batch,
            q_max_seqlen=q_max_seqlen,
            kv_max_seqlen=kv_max_seqlen,
            attn_heads=attn_heads,
            num_gqa_groups=num_gqa_groups,
            bias_heads=bias_heads,
            qk_head_dim=qk_head_dim,
            v_head_dim=v_head_dim,
            max_segments_per_seq=config.max_segments_per_seq,
            scaling_factor=float(config.scaling_factor),
            dropout_probability=float(config.dropout_probability),
            bias_type=int(config.attn_bias_type.value),
            mask_type=int(config.attn_mask_type.value),
            qkv_layout=int(config.qkv_layout.value),
            is_training=config.is_training,
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            softmax_type=int(config.softmax_type.value),
        )

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        softmax_offset,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config,
    ):
        """
        Fused attention bwd implementation.

        Derives seqlens/offsets from the sequence-descriptor pieces. For THD layouts it
        compacts them into the packed form the kernel expects, builds cumulative
        seqlens, and binds the inner primitive. The workspace output is dropped.
        """
        assert FusedAttnBwdPrimitive.inner_primitive is not None

        sequence_descriptor = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )

        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
                config.qkv_layout,
                config.window_size,
                config.max_segments_per_seq,
            )
        )

        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
                # Move the entries selected by `condition` to the front (row-major
                # order) and pad the tail with `fill_value`, keeping x's shape.
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                # TODO(rewang): try indices_are_sorted
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
                # Shift each row's non-negative offsets by row_index * max_seqlen;
                # negative (unused) entries are left as they are.
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1
            kv_batch = q_batch = batch[0]

            # Gather valid q_seqlen, which is greater than 0
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1
            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid q_seq_offsets, which is greater and equal to 0
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # And set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _fix_len_take(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _fix_len_take(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        dq, dk, dv, dbias, dsoftmax_offset, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return dq, dk, dv, dbias, dsoftmax_offset

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """vmap rule: bind the outer primitive on the batched operands.

        Each gradient reports the batch dimension of its primal operand
        (q, k, v, bias, softmax_offset).
        """
        check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims

        out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
        return (
            FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """Each gradient takes the (padded) sharding spec of its primal operand."""
        del config, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        softmax_offset_spec = get_padded_spec(arg_infos[4])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
        return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding, dsoftmax_offset_sharding)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Custom-partitioning rule for the backward primitive.

        Output shardings mirror the primal operands, computed by
        `infer_sharding_from_operands` so both rules stay in sync. The two
        segment_pos operands are sharded like the matching segment_ids operands.
        """
        out_shardings = FusedAttnBwdPrimitive.infer_sharding_from_operands(
            config, mesh, arg_infos, result_infos
        )
        del result_infos

        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # The last four operands are (q_segment_ids, kv_segment_ids, q_segment_pos,
        # kv_segment_pos); make each *_pos follow its *_ids sharding.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)

        def sharded_impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            local_dq, local_dk, local_dv, local_dbias, local_dsoftmax_offset = (
                FusedAttnBwdPrimitive.impl(
                    q,
                    k,
                    v,
                    bias,
                    softmax_offset,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_cu_seqlen,
                    kv_cu_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    _q_segment_ids,
                    _kv_segment_ids,
                    _q_segment_pos,
                    _kv_segment_pos,
                    config=config,
                )
            )
            # Sum per-shard dbias across the dp/fsdp mesh axes when a bias is present
            # (presumably because the bias is not sharded along those axes).
            global_dbias = local_dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)

            # Only the learnable softmax offset gets a summed gradient.
            global_dsoftmax_offset = local_dsoftmax_offset
            if config.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
                global_dsoftmax_offset = all_reduce_sum_along_dp_fsdp(local_dsoftmax_offset, mesh)

            return local_dq, local_dk, local_dv, global_dbias, global_dsoftmax_offset

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(config, mesh, value_types, result_types):
        """Shardy rule: every operand and result gets an independent placeholder factor."""
        del config, mesh
        # Keep in sync with `infer_sharding_from_operands`.
        input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
        output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
        return SdyShardingRule(input_spec, output_spec)

1181
1182
1183
1184

# Register the backward fused-attention primitive through the shared helper from `.base`.
# `FusedAttnBwdPrimitive.impl`/`batcher` assert that `inner_primitive`/`outer_primitive`
# are no longer None; presumably this registration sets them (TODO confirm in `.base`).
register_primitive(FusedAttnBwdPrimitive)


Reese Wang's avatar
Reese Wang committed
1185
def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
    """Reorders a tensor for load balancing the compute of causal attention.

    The sequence axis ``seq_dim`` is cut into ``2 * cp_size`` equal chunks.

    * ``to_contiguous=False`` lays the chunks out as
      ``[0, 2c-1, 1, 2c-2, ...]`` (c = cp_size), so contiguous shard ``r`` holds
      chunks ``r`` and ``2c-1-r``.
    * ``to_contiguous=True`` applies the inverse permutation, restoring the
      original chunk order (this direction requires an even ``cp_size``).

    Returns ``tensor`` itself when ``cp_size == 1``.

    Raises:
        ValueError: if ``cp_size`` is odd, or the sequence length is not a
            multiple of ``2 * cp_size``.
    """
    if cp_size == 1:
        return tensor

    if cp_size % 2 != 0:
        raise ValueError(f"{cp_size=} must be a multiple of 2.")

    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
    if tensor.shape[seq_dim] % (cp_size * 2) != 0:
        raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")

    num_chunks = 2 * cp_size
    original_shape = tensor.shape
    # [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]
    # [S, B, H, D] -> [2*cp_size, S/(2*cp_size), B, H, D]
    chunked = tensor.reshape(
        (
            *original_shape[:seq_dim],
            num_chunks,
            original_shape[seq_dim] // num_chunks,
            *original_shape[seq_dim + 1 :],
        )
    )

    if to_contiguous:
        half = cp_size // 2
        # Even chunk ids sit at positions 0, 2, 4, ...; the odd positions hold the
        # mirrored chunks in reverse, starting from position 2c-1.
        front = [(4 * i, 4 * i + 2) for i in range(half)]
        back = [(num_chunks - 1 - 4 * i, num_chunks - 3 - 4 * i) for i in range(half)]
        index_pairs = front + back
    else:
        index_pairs = [(rank, num_chunks - rank - 1) for rank in range(cp_size)]

    # Each gather yields a pair of chunks: [..., 2, S/(2*cp_size), ...]
    gathered = [jnp.take(chunked, jnp.array(pair), axis=seq_dim) for pair in index_pairs]

    # Stacking gives [..., cp_size, 2, S/(2*cp_size), ...], which folds back to the input shape.
    return jnp.stack(gathered, axis=seq_dim).reshape(original_shape)


Reese Wang's avatar
Reese Wang committed
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
    """Reorders a tensor for load balancing with striped pattern.

    Forward (``is_inverse=False``): position ``i * cp_size + r`` moves to
    ``r * (S // cp_size) + i``, so contiguous shard ``r`` ends up holding tokens
    ``r, r + cp_size, r + 2*cp_size, ...``. ``is_inverse=True`` undoes this.

    Raises:
        ValueError: if the length of ``seq_dim`` is not a multiple of ``cp_size``.
    """
    origin_shape = tensor.shape
    seq_len = origin_shape[seq_dim]
    if seq_len % cp_size != 0:
        raise ValueError(
            "Expected origin_shape[seq_dim] is multiple of cp_size but got"
            f" {origin_shape[seq_dim]=} and {cp_size=}"
        )

    per_rank = seq_len // cp_size
    split_dims = (cp_size, per_rank) if is_inverse else (per_rank, cp_size)
    chunked = tensor.reshape(
        (*origin_shape[:seq_dim], *split_dims, *origin_shape[seq_dim + 1 :])
    )
    # Transposing the two split axes interleaves (or de-interleaves) the stripes.
    return jnp.swapaxes(chunked, seq_dim, seq_dim + 1).reshape(origin_shape)


1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
    """Helper class to assist with running the all-gather strategy for CP attention.

    K/V tensors are all-gathered along axis 1 over the context-parallel mesh axis
    before attention. dK/dV are reduce-scattered back along that axis afterwards.
    """

    mesh: jax.sharding.Mesh
    config: _FusedAttnConfig

    def _apply_to_kv(self, k, v, fn):
        """Apply ``fn`` to the K/V tensors of the configured layout.

        KV-packed: only ``k`` is transformed (it presumably carries both K and V) and
        ``v`` passes through. Separate: both are transformed. Any other layout:
        both are returned unchanged.
        """
        layout = self.config.qkv_layout
        if layout.is_kvpacked():
            return fn(k), v
        if layout.is_separate():
            return fn(k), fn(v)
        return k, v  # fall through

    def check_supported(self):
        """Raise ValueError when this config cannot use the all-gather CP path."""
        header = "Context parallel fused attention"
        cfg = self.config

        allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
        if cfg.qkv_layout not in allowed_layouts:
            layouts_text = ",".join(map(str, allowed_layouts))
            raise ValueError(
                f"{header} only supports layouts: {layouts_text} got: {cfg.qkv_layout}"
            )

        if cfg.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {cfg.attn_bias_type}")

        allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
        if cfg.attn_mask_type not in allowed_masks:
            masks_text = ",".join(map(str, allowed_masks))
            raise ValueError(
                f"{header} only supports masking types:  {masks_text} got: {cfg.attn_mask_type}"
            )

        if cfg.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got: {cfg.max_segments_per_seq}"
            )

        if cfg.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

        if cfg.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            raise ValueError(f"{header} only supports VANILLA_SOFTMAX, got: {cfg.softmax_type}")

    def get_adjusted_mask(self):
        """Converts the mask for context parallelism.

        CAUSAL_MASK becomes CAUSAL_BOTTOM_RIGHT_MASK; any other mask is kept.
        """
        mask = self.config.attn_mask_type
        return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK if mask == AttnMaskType.CAUSAL_MASK else mask

    def get_step_config(self) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for a single CP-step call to fused attention.

        Copies the original config, swapping in the adjusted mask and clearing
        ``cp_striped_window_size``.
        """
        cfg = self.config
        return _FusedAttnConfig(
            attn_bias_type=cfg.attn_bias_type,
            attn_mask_type=self.get_adjusted_mask(),
            softmax_type=cfg.softmax_type,
            qkv_layout=cfg.qkv_layout,
            scaling_factor=cfg.scaling_factor,
            dropout_probability=cfg.dropout_probability,
            is_training=cfg.is_training,
            max_segments_per_seq=cfg.max_segments_per_seq,
            window_size=cfg.window_size,
            context_parallel_load_balanced=cfg.context_parallel_load_balanced,
            cp_axis=cfg.cp_axis,
            cp_striped_window_size=None,
        )

    def all_gather_kv(self, k, v):
        """All-gather k and v over the context-parallel ranks along axis 1.

        With load balancing enabled, the gathered tensor is put back into
        contiguous chunk order.
        """

        def gather_seq(x):
            gathered = lax_paral_op(
                x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
            )
            if not self.config.context_parallel_load_balanced:
                return gathered
            cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
            return reorder_causal_dual_chunk_swap(gathered, cp_size, 1, to_contiguous=True)

        return self._apply_to_kv(k, v, gather_seq)

    def reduce_scatter_dkv(self, dk, dv):
        """Reduce-scatter dk and dv over the context-parallel ranks along axis 1.

        With load balancing enabled, the tensor is first put into the
        dual-chunk-swapped order so each rank receives its own chunks.
        """

        def scatter_seq(x):
            if self.config.context_parallel_load_balanced:
                cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
                x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)
            return lax_paral_op(
                x,
                lax.psum_scatter,
                self.config.cp_axis,
                mesh=self.mesh,
                scatter_dimension=1,
                tiled=True,
            )

        return self._apply_to_kv(dk, dv, scatter_seq)

    def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
        """Returns sequence lengths of KV to use for each sub rank of the given cp_rank.

        Example: CP=4, MaxLen = 1024, Unbalanced
           cp_rank 0: [128, 256]
           cp_rank 1: [384, 512]
           cp_rank 2: [640, 768]
           cp_rank 3: [896, 1024]

        Example: CP=4, MaxLen = 1024, Balanced
           cp_rank 0: [128, 1024]
           cp_rank 1: [256, 896]
           cp_rank 2: [384, 768]
           cp_rank 3: [512, 640]
        """
        if self.config.context_parallel_load_balanced:
            first_subrank = (cp_rank + 1) * kv_seqlen_per_subrank
            second_subrank = kv_max_seqlen - cp_rank * kv_seqlen_per_subrank
        else:
            first_subrank = (cp_rank * 2 + 1) * kv_seqlen_per_subrank
            second_subrank = (cp_rank * 2 + 2) * kv_seqlen_per_subrank
        return [first_subrank, second_subrank]

    def slice_kv(self, k, v, slice_seq_len):
        """Slices k and v to their first ``slice_seq_len`` entries along axis 1."""
        return self._apply_to_kv(
            k, v, lambda x: lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)
        )

    def pad_kv(self, dk, dv, pad_seq_len):
        """Zero-pads dk and dv with ``pad_seq_len`` trailing entries along axis 1."""
        if self.config.qkv_layout.is_kvpacked():
            pad_widths = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
        else:
            pad_widths = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
        return self._apply_to_kv(
            dk, dv, lambda x: jnp.pad(x, pad_widths, "constant", constant_values=0.0)
        )


class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Attention Forward with Context Parallelism Primitive

    This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Builds the SPMD-partitioned forward for all-gather context parallelism.

        When the CP mesh axis has size 1 this defers to ``FusedAttnFwdPrimitive.partition``.
        Otherwise every rank all-gathers K/V and runs two fused-attention calls, one per half
        of its local query shard (split along axis 1). Each half attends to the full gathered
        KV (``NO_MASK``) or to a leading slice of it whose length comes from
        ``helper.kv_seqlens_for_rank``.

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)`` for custom partitioning.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        # Seed and RNG state have their leading dimension sharded over every mesh axis.
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[5] = seed_sharding  # operand 5 is `seed` (see impl signature)
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # cuDNN does not support right-aligned masking with dynamic sequence length padding.
            # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
            # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
            # meeting the expectation of the SPMD model.
            # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
            # mask/sequence length tensor to avoid this unrolled loop.
            def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed):
                """Forward for CP rank ``idx`` on the all-gathered ``k``/``v``.

                The local query shard is split into two halves (axis 1). The two partial
                outputs are concatenated on axis 1 and the softmax stats on axis 2.
                """
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                q_split = jnp.split(q, 2, axis=1)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == AttnMaskType.NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    # Each query half holds 1/(2*cp) of the global sequence. Integer division
                    # keeps the seqlens integral, matching the backward primitive.
                    q_seqlen_for_step = q_seqlen // (cp_size * 2)
                    # The KV visible to this step spans `kv_seqlens_for_rank[sub_idx]` positions,
                    # i.e. this many sub-rank chunks of `kv_seqlen_per_subrank` each.
                    num_kv_chunks = kv_seqlens_for_rank[sub_idx] // kv_seqlen_per_subrank
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

                    output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        softmax_offset,
                        seed,
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )
                    results.append((output, softmax_aux, rng_state))

                output = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2)
                rng_state = results[1][2]  # Use the final RNG state

                return output, softmax_aux, rng_state

            k_ag, v_ag = helper.all_gather_kv(k, v)

            # One statically-specialized branch per CP rank; `partial` binds `idx` eagerly.
            functions = [
                partial(
                    _cross_attn, idx, q, k_ag, v_ag, bias, softmax_offset, q_seqlen, kv_seqlen, seed
                )
                for idx in range(cp_size)
            ]

            return lax.switch(cp_rank, functions)

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPWithAllGatherFwdPrimitive)


class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Attention Backward with Context Parallelism Primitive.

    This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks.
    The gradients are subsequently reduce-scattered back to each context parallel rank.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Builds the SPMD-partitioned backward for all-gather context parallelism.

        When the CP mesh axis has size 1 this defers to ``FusedAttnBwdPrimitive.partition``.
        Otherwise every rank all-gathers K/V and runs the backward separately on each half of
        its local query shard. Each half's dK/dV is zero-padded back to the full gathered
        length, the two halves are summed, and the result is reduce-scattered across CP ranks.

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)`` for custom partitioning.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Ensure we can support this configuration with context parallelism.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        softmax_offset_spec = get_padded_spec(arg_infos[4])
        # Each gradient is sharded exactly like its forward operand.
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            dq_sharding,
            dk_sharding,
            dv_sharding,
            dbias_sharding,
            dsoftmax_offset_sharding,
        )

        def impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # See comment in FusedAttnCPWithAllGatherFwdPrimitive.partition for why we define
            # this function and dispatch on the CP rank with lax.switch.
            def _cross_attn_bwd(
                idx,
                q,
                k,
                v,
                bias,
                softmax_offset,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlen,
                kv_seqlen,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
            ):
                """Backward for CP rank ``idx`` on the all-gathered ``k``/``v``.

                Returns ``(dq, dk_padded, dv_padded, dbias)``. dK/dV keep the full gathered
                sequence length so that they can be reduce-scattered.
                """
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                q_split = jnp.split(q, 2, axis=1)
                output_split = jnp.split(output, 2, axis=1)
                doutput_split = jnp.split(doutput, 2, axis=1)
                softmax_aux_split = jnp.split(softmax_aux, 2, axis=2)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == AttnMaskType.NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    q_seqlen_for_step = q_seqlen // (cp_size * 2)
                    # The KV visible to this step spans `kv_seqlens_for_rank[sub_idx]` positions,
                    # i.e. this many sub-rank chunks of `kv_seqlen_per_subrank` each.
                    num_kv_chunks = kv_seqlens_for_rank[sub_idx] // kv_seqlen_per_subrank
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

                    dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        softmax_offset,
                        softmax_aux_split[sub_idx],
                        rng_state,
                        output_split[sub_idx],
                        doutput_split[sub_idx],
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )

                    # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
                    if config.attn_mask_type != AttnMaskType.NO_MASK:
                        pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx]
                        dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length)

                    results.append((dq_local, dk_local, dv_local, dbias_local))

                dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                dk_local_pad = results[0][1] + results[1][1]
                dv_local_pad = results[0][2] + results[1][2]
                # NOTE(review): only the second half's dbias is returned; this is presumably fine
                # because helper.check_supported() rejects bias -- confirm if bias support is added.
                return dq_local, dk_local_pad, dv_local_pad, results[1][3]

            k_ag, v_ag = helper.all_gather_kv(k, v)

            # One statically-specialized branch per CP rank; `partial` binds `idx` eagerly.
            functions = [
                partial(
                    _cross_attn_bwd,
                    idx,
                    q,
                    k_ag,
                    v_ag,
                    bias,
                    softmax_offset,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_seqlen,
                    kv_seqlen,
                    _q_segment_ids,
                    _kv_segment_ids,
                    _q_segment_pos,
                    _kv_segment_pos,
                )
                for idx in range(cp_size)
            ]

            dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
            dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)

            # Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it).
            # jnp.zeros_like is explicit about the value (JAX's empty_like also yields zeros).
            dummy_dsoftmax_offset = jnp.zeros_like(softmax_offset)
            return dq, dk, dv, dbias, dummy_dsoftmax_offset

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)


@dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper:
    """Helper class to assist with running the P2P ring strategy for CP attention.

    Holds the device mesh the primitive is partitioned over and the fused-attention config of
    the full (un-split) call.
    """

    mesh: jax.sharding.Mesh
    config: _FusedAttnConfig

    @staticmethod
    def use_scanloop():
        """Returns true if the implementation will use a scan loop for iteration.

        Read from ``NVTE_FUSED_RING_ATTENTION_USE_SCAN`` (default ``"1"``). The value is parsed
        as an integer, and any non-zero value enables the scan loop.
        """
        return bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))

    def check_supported(self):
        """Checks if the context parallel implementation is supported by the given arguments.

        Raises:
            ValueError: if the layout, bias type, mask type, segments per sequence, dropout,
                softmax type, or the THD + sliding-window + scan-loop combination is not
                supported by ring attention.
        """
        header = "Context parallel fused ring attention"

        if self.config.qkv_layout.is_thd():
            allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
        else:
            allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]

        if self.config.qkv_layout not in allowed_layouts:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
            )

        if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

        if self.config.qkv_layout.is_thd():
            allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK]
        else:
            allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]

        if self.config.attn_mask_type not in allowed_masks:
            raise ValueError(
                f"{header} only supports masking types: "
                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
            )

        if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got:"
                f" {self.config.max_segments_per_seq}"
            )

        if self.config.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

        if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            raise ValueError(
                f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
            )

        # We want to encourage use of scan loop to minimize unrolling and ensure more
        # predictable scheduling from XLA. The unrolled flavor will be supported but
        # not the preferred implementation.
        if not self.use_scanloop():
            warnings.warn(
                "Scan loop is disabled for fused ring attention. To enable set"
                " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
            )

        # If using scanloop, idx in scan_kv_block() will be a traced device value, but
        # _normalize_window_size_for_cp_striped() requires all parameters to be host values
        is_context_parallel = get_mesh_axis_size(self.config.cp_axis, self.mesh) > 1
        is_thd_layout = self.config.qkv_layout.is_thd()
        is_sliding_window = self.config.window_size[0] != -1
        if is_context_parallel and is_thd_layout and is_sliding_window and self.use_scanloop():
            raise ValueError(
                f"{header} with THD format and sliding window does not support using scan loop"
            )

    def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention.

        The step config copies the full config but overrides ``attn_mask_type`` with the given
        value, forces ``qkv_layout`` to ``BSHD_BS2HD`` (steps receive stacked KV, see
        ``stack_kv``), and clears ``cp_striped_window_size``.
        """
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
            attn_mask_type=attn_mask_type,
            softmax_type=self.config.softmax_type,
            qkv_layout=QKVLayout.BSHD_BS2HD,
            scaling_factor=self.config.scaling_factor,
            dropout_probability=self.config.dropout_probability,
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
            window_size=self.config.window_size,
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
            cp_striped_window_size=None,
        )

    def stack_kv(self, k, v):
        """Stacks k and v tensors if not stacked.

        KV-packed layouts return ``k`` unchanged. Separate layouts stack ``k`` and ``v`` along a
        new axis 2. Any other layout yields an empty placeholder array of ``k``'s dtype.
        """
        if self.config.qkv_layout.is_kvpacked():
            return k
        if self.config.qkv_layout.is_separate():
            return jnp.stack([k, v], axis=2)
        return jnp.zeros(0, dtype=k.dtype)  # fall through: placeholder

    def unstack_kv(self, kv):
        """Un-stacks k and v tensors if not stacked.

        KV-packed layouts return ``(kv, placeholder)``. Separate layouts unstack along axis 2
        (the inverse of ``stack_kv``). Any other layout returns two empty placeholders.
        """
        if self.config.qkv_layout.is_kvpacked():
            return kv, jnp.zeros(0, dtype=kv.dtype)
        if self.config.qkv_layout.is_separate():
            return jnp.unstack(kv, axis=2)
        _not_used = jnp.zeros(0, dtype=kv.dtype)
        return _not_used, _not_used  # fall through

    def permute_kv(self, kv, cp_perm):
        """Permutes kv around the ring as described by cp_perm (a list of (src, dst) pairs)."""
        return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)

    @staticmethod
    def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
        """
        Corrects the output and softmax_aux tensor after each iteration of ring attention.

        With ``lse`` = ``softmax_aux`` (accumulated) and ``lse_p`` = ``partial_softmax_aux``:

            new_out = out - sigmoid(lse_p - lse) * (out - out_p)
            new_lse = lse - log_sigmoid(lse - lse_p)

        The stats are transposed with axes (0, 2, 1, 3) before broadcasting against the
        output. Ring-attention callers in this file allocate stats as
        ``(batch, head, seq, 1)`` and outputs as ``(batch, seq, head, dim)``.

        See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
        derivation of this equation.
        """
        new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
            0, 2, 1, 3
        ) * (output - partial_output)
        new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
        return new_out, new_aux

    def adjust_seqlen(self, seqlen, max_seqlen, idx):
        """Adjust the sequence length per step.

        Returns the number of valid tokens falling in chunk ``idx`` of size ``max_seqlen``,
        i.e. ``seqlen - max_seqlen * idx`` clamped to ``[0, max_seqlen]``.
        """
        return jnp.clip(seqlen - max_seqlen * idx, 0, max_seqlen)


class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Ring Attention Forward Primitive
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Builds the SPMD-partitioned ring-attention forward.

        When the CP mesh axis has size 1 this defers to ``FusedAttnFwdPrimitive.partition``.
        Otherwise each rank keeps its local Q and passes its KV block around the ring:
        rank ``i`` sends to rank ``(i + 1) % cp_size`` at every step. After ``idx`` steps,
        rank ``r`` therefore holds the KV that started on rank ``(r - idx) % cp_size``.
        Each step's partial result is merged into the running output with
        ``helper.correct_output_and_softmax_aux``.

        Returns:
            ``(mesh, ring_attn_fwd_impl, out_shardings, arg_shardings)``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        # Seed and RNG state have their leading dimension sharded over every mesh axis.
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[5] = seed_sharding  # operand 5 is `seed` (see impl signature)
        # Ensure segment_pos gets same sharding as ID.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def ring_attn_fwd_impl(
            q,
            k,
            v,
            bias,
            _softmax_offset,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            batch, q_max_seqlen, head, _ = q.shape
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            # Running output and log-sum-exp stats are accumulated in float32.
            output = jnp.zeros(q.shape, dtype=jnp.float32)
            softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape, dtype=result_infos[2].dtype)

            def scan_kv_block(idx, carry):
                """One ring step: attend local Q to the current KV block, then merge."""
                kv, output, softmax_aux = carry

                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)

                def mask_compute(attn_mask_type):
                    # Full local Q against the full current KV block.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        _softmax_offset,
                        seed,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(attn_mask_type),
                    )
                    return output_per_step, softmax_aux_per_step

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # Full local Q against the first half of the current KV block.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
                    kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1)
                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                        q,
                        kv_part,
                        _not_used,
                        bias,
                        _softmax_offset,
                        seed,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(AttnMaskType.NO_MASK),
                    )
                    return output_per_step, softmax_aux_per_step

                def half_q_no_mask_compute():
                    # Second half of local Q against the full current KV block. The first half
                    # contributes nothing: zero output and -inf stats.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                        q_part,
                        kv,
                        _not_used,
                        bias,
                        _softmax_offset,
                        seed,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(AttnMaskType.NO_MASK),
                    )
                    output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
                    softmax_aux_per_step = jnp.concat(
                        [
                            jnp.full_like(softmax_aux_per_step, -jnp.inf),
                            softmax_aux_per_step,
                        ],
                        axis=2,
                    )
                    return output_per_step, softmax_aux_per_step

                def skip_compute():
                    # KV block lies entirely after local Q under causal masking: no contribution.
                    output_per_step = jnp.zeros_like(q)
                    softmax_aux_per_step = jnp.full(
                        (batch, head, q.shape[1], 1), -jnp.inf, dtype=jnp.float32
                    )
                    return output_per_step, softmax_aux_per_step

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # Step 0 is the diagonal block (own KV) and needs the causal mask. Later
                    # steps hold KV from earlier ranks while idx <= cp_rank.
                    # This is for nested jax.lax.cond
                    def jax_cond_wrap():
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    output_per_step, softmax_aux_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    output_per_step, softmax_aux_per_step = no_mask_compute()

                def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    # pylint: disable=unused-argument
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    return helper.correct_output_and_softmax_aux(
                        output, softmax_aux, output_per_step, softmax_aux_per_step
                    )

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    (idx == 0),
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, output, softmax_aux)

            carry = (kv, output, softmax_aux)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(cp_size):
                    carry = scan_kv_block(i, carry)
            (kv, output, softmax_aux) = carry

            output = output.astype(q.dtype)
            return output, softmax_aux, rng_state

        return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnFwdPrimitive)


class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Ring Attention Backward Primitive

    Context-parallel backward pass. When the ``config.cp_axis`` mesh axis spans more than one
    device, every rank keeps its local Q / output / doutput shard. On each of ``cp_size`` steps
    it sends the stacked K/V shard, together with the dK/dV accumulated so far, to the next rank
    in the ring (``i -> (i + 1) % cp_size``) and runs the local fused-attention backward on the
    K/V block it currently holds. Without context parallelism it defers to
    ``FusedAttnBwdPrimitive.partition``.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the sharded ring-attention backward implementation.

        Returns:
            ``FusedAttnBwdPrimitive.partition(...)`` when context parallelism is inactive,
            otherwise ``(mesh, ring_attn_bwd_impl, out_shardings, arg_shardings)``. The outputs
            of ``ring_attn_bwd_impl`` are ``(dq, dk, dv, dbias, dsoftmax_offset)``, sharded like
            ``(q, k, v, bias, softmax_offset)``.

        Raises:
            AssertionError: if ``config.window_size[0] != -1`` while context parallelism is on.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        softmax_offset_spec = get_padded_spec(arg_infos[4])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        # Ring attention doesn't use dsoftmax_offset, but we need to return it for arity matching
        dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))

        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Ensure segment_pos gets the same sharding as the matching segment_ids
        # (the last four args are q/kv segment_ids followed by q/kv segment_pos).
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            dq_sharding,
            dk_sharding,
            dv_sharding,
            dbias_sharding,
            dsoftmax_offset_sharding,
        )

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def ring_attn_bwd_impl(
            q,
            k,
            v,
            bias,
            _softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            """Per-shard ring backward: accumulate dq/dk/dv/dbias over all cp_size KV blocks."""
            _not_used = jnp.zeros(0, dtype=output.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            q_max_seqlen = q.shape[1]
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            # Each rank sends to its right neighbour on every step.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
            dbias = jnp.zeros_like(bias)

            def scan_kv_block(idx, carry):
                """One ring step: compute gradients for the held KV block, rotate KV and dK/dV."""
                kv, dq, dk_dv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.
                kv_dk_dv = jnp.stack([kv, dk_dv])
                kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)

                def mask_compute(attn_mask_type):
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
                    # The 5th output (dsoftmax_offset) is discarded; see the note at the end.
                    dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        _softmax_offset,
                        softmax_aux,
                        rng_state,
                        output,
                        doutput,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(attn_mask_type),
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # Use only the first half (along axis 1) of the held KV block; the gradient
                    # for the second half is zero-padded back in.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
                    kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
                    dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
                        q,
                        kv_part,
                        _not_used,
                        bias,
                        _softmax_offset,
                        softmax_aux,
                        rng_state,
                        output,
                        doutput,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(AttnMaskType.NO_MASK),
                    )
                    dk_dv_per_step = jnp.concat(
                        [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def half_q_no_mask_compute():
                    # Use only the second half (along axis 1) of the local queries; the dq for
                    # the first half is zero-padded back in.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)

                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
                    doutput_part = lax.slice_in_dim(
                        doutput, q_max_seqlen // 2, q_max_seqlen, axis=1
                    )
                    output_part = lax.slice_in_dim(output, q_max_seqlen // 2, q_max_seqlen, axis=1)

                    # softmax_aux is sliced on axis 2 (its sequence axis, presumably) — confirm.
                    softmax_aux_part = lax.slice_in_dim(
                        softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
                    )

                    dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
                        q_part,
                        kv,
                        _not_used,
                        bias,
                        _softmax_offset,
                        softmax_aux_part,
                        rng_state,
                        output_part,
                        doutput_part,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(AttnMaskType.NO_MASK),
                    )
                    dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def skip_compute():
                    return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # Step 0 processes the rank's own KV block with a causal mask. Later steps
                    # use no mask: with load balancing, blocks with idx <= cp_rank use half of
                    # KV and the rest use half of Q; without load balancing, blocks with
                    # idx > cp_rank are skipped entirely.

                    # This is for nested jax.lax.cond
                    def jax_cond_wrap():
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute()

                # dk_dv was rotated together with kv, so the per-step gradient is added to the
                # partial sum that belongs to the block just computed.
                kv_next, dk_dv = jnp.unstack(kv_dk_dv)
                dq = dq + dq_per_step
                dk_dv = dk_dv + dk_dv_per_step
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step

                return (kv_next, dq, dk_dv, dbias)

            carry = (kv, dq, dk_dv, dbias)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (kv, dq, dk_dv, dbias) = carry

            # Final permute to put gradients back to their final resting place.
            dk_dv = helper.permute_kv(dk_dv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dk_dv)
            # Return a zero dsoftmax_offset for arity matching (ring attention doesn't use it).
            # NOTE(review): the per-step dsoftmax_offset from FusedAttnBwdPrimitive.impl is
            # discarded above even though softmax_offset is fed to every step; presumably
            # helper.check_supported() rejects softmax types that need it — confirm.
            dummy_dsoftmax_offset = jnp.zeros_like(_softmax_offset)
            return dq, dk, dv, global_dbias, dummy_dsoftmax_offset

        return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnBwdPrimitive)


2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
def adjust_cp_striped_window_size(q_pos0, kv_pos0, cp_size, window_size):
    """
    Adjust window size with cp_size for striped sharding, where both q_pos and
    kv_pos are arithmetic sequences like [x, x+cp_size, x+2*cp_size, ...].

    Args:
        q_pos0: First (global) query position held by this rank; a host int.
        kv_pos0: First (global) key/value position of the current KV block; a host int.
        cp_size: Stride between consecutive positions (context-parallel size).
        window_size: ``(left, right)`` sliding window in global positions. A negative
            component means that side is unbounded, e.g. ``(-1, 0)`` is plain causal.

    Returns:
        ``(left, right)`` window expressed in steps of ``cp_size``. An unbounded input side is
        returned as ``-1``.

    Example 1:
        q_pos = kv_pos = [0, 8, 16, 24, 32], cp_size = 8, window_size = (15, 0).
        q_pos = 32 can look at kv_pos at [24, 32]. The effective mask is:
              0  8 16 24 32
           ----------------
         0 |  1  0  0  0  0
         8 |  1  1  0  0  0
        16 |  0  1  1  0  0
        24 |  0  0  1  1  0
        32 |  0  0  0  1  1
        SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
        Adjusted window size = (1, 0).
    Example 2:
        q_pos = [0, 8, 16, 24, 32], kv_pos = [1, 9, 17, 25, 33], cp_size = 8,
        window_size = (15, 0). The effective mask is:
              1  9 17 25 33
           ----------------
         0 |  0  0  0  0  0
         8 |  1  0  0  0  0
        16 |  1  1  0  0  0
        24 |  0  1  1  0  0
        32 |  0  0  1  1  0
        SequenceDescriptor outputs:
        q_seqlen = [4, ...], q_seq_offsets = [1, ...],
        kv_seqlen = [4, ...], kv_seq_offsets = [0, ...].
        If diagonal are all 1, left window size = 2. Now since diagonal are all 0,
        we need to use left window size = 2 - 1 = 1 to make cuDNN work.
    Example 3:
        q_pos = [7, 15, 23, 31, 39], kv_pos = [0, 8, 16, 24, 32], cp_size = 8,
        window_size = (22, 0). The effective mask is:
              0  8 16 24 32
           ----------------
         7 |  1  0  0  0  0
        15 |  1  1  0  0  0
        23 |  0  1  1  0  0
        31 |  0  0  1  1  0
        39 |  0  0  0  1  1
        SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
        Adjusted window size = (1, 0).

    NOTE(review): with kv_pos0 > q_pos0 and a finite left window narrower than the gap to the
    previous stride, the shifted left result is -1, which may be read as "unbounded" downstream.
    Confirm that such configurations are rejected elsewhere.
    """
    left_window, right_window = window_size

    # Count how many left/right steps of size cp_size we can take from kv_pos0 -/+ cp_size.
    # An unbounded side (negative window) stays unbounded in the striped frame.
    if left_window < 0:
        left_steps = -1
    else:
        left_limit = q_pos0 - left_window
        left_steps = max((kv_pos0 - cp_size - left_limit) // cp_size + 1, 0)
        # If kv_pos0 > q_pos0 the diagonal is masked out, so the left window shrinks by one.
        if kv_pos0 > q_pos0:
            left_steps -= 1

    if right_window < 0:
        right_steps = -1
    else:
        right_limit = q_pos0 + right_window
        right_steps = max((right_limit - kv_pos0 - cp_size) // cp_size + 1, 0)

    return left_steps, right_steps


Reese Wang's avatar
Reese Wang committed
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Striped Ring Attention Forward Primitive

    Context-parallel forward pass for striped (THD, segment-id based) layouts. Each rank keeps its
    queries and rotates the stacked K/V block, plus its segment ids/positions, around the ring.
    Per-step outputs are merged in float32 using the running softmax statistics. Without
    context parallelism it defers to ``FusedAttnFwdPrimitive.partition``.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the sharded striped ring-attention forward implementation.

        Returns:
            ``FusedAttnFwdPrimitive.partition(...)`` when context parallelism is inactive,
            otherwise ``(mesh, fwd_impl, out_shardings, arg_shardings)``. ``fwd_impl`` returns
            ``(output, softmax_aux, rng_state)``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Argument 5 is the seed (q, k, v, bias, softmax_offset, seed, ...).
        arg_shardings[5] = seed_sharding
        # Ensure segment_pos gets same sharding as ID.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def fwd_impl(
            q,
            k,
            v,
            bias,
            _softmax_offset,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            """Per-shard striped ring forward over all cp_size KV blocks.

            Raises:
                ValueError: if segment ids are not provided (empty arrays).
            """
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            if not config.qkv_layout.is_qkvpacked():
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # We need cp_rank to be a host value for adjust_cp_striped_window_size()
            cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
            # Each rank sends to its right neighbour on every step.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            # Accumulators are kept in float32 and cast back to q.dtype at the end.
            batch, q_max_seqlen, head, _ = q.shape
            output = jnp.zeros(q.shape).astype(jnp.float32)
            softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

            def scan_kv_block(idx, carry):
                """One ring step: attend to the held KV block, merge, and rotate KV + metadata."""
                kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry

                # TODO(rewang): To check whether we need special handle for the last idx
                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                def compute(config):
                    return FusedAttnFwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        _softmax_offset,
                        seed,
                        q_seqlen,
                        kv_seqlen,
                        q_seq_offsets,
                        k_seq_offsets,
                        q_segment_ids,
                        kv_segment_ids,
                        q_segment_pos,
                        kv_segment_pos,
                        config=config,
                    )

                if config.window_size != (-1, -1):
                    # The block held at step idx originated on rank (cp_rank - idx) mod cp_size.
                    kv_src_rank = (cp_size + cp_rank - idx) % cp_size
                    # Note: all inputs of adjust_cp_striped_window_size should be host values.
                    # NOTE(review): under lax.fori_loop (helper.use_scanloop()) idx is traced;
                    # presumably check_supported() excludes that combination — confirm.
                    cp_striped_window_size = adjust_cp_striped_window_size(
                        cp_rank, kv_src_rank, cp_size, config.window_size
                    )
                    current_config = replace(
                        subblock_config, cp_striped_window_size=cp_striped_window_size
                    )
                else:
                    current_config = subblock_config
                output_per_step, softmax_aux_per_step, _ = compute(current_config)

                softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))

                def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # Merge the new partial result into the running one, weighting by the
                    # sigmoid of the softmax_aux difference, and update softmax_aux to match.
                    new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * (
                        output - output_per_step
                    )
                    new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step)
                    return new_out, new_aux

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    idx == 0,
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux)

            carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (_, _, _, output, softmax_aux) = carry

            return output.astype(q.dtype), softmax_aux, rng_state

        return mesh, fwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnStripedFwdPrimitive)


class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Striped Ring Attention Backward Primitive

    Context-parallel backward pass for striped (THD, segment-id based) layouts. Each rank keeps
    its queries. On each step it rotates the stacked K/V block together with the accumulated
    dK/dV and the KV segment ids/positions around the ring. Without context parallelism it defers
    to ``FusedAttnBwdPrimitive.partition``.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the sharded striped ring-attention backward implementation.

        Returns:
            ``FusedAttnBwdPrimitive.partition(...)`` when context parallelism is inactive,
            otherwise ``(mesh, bwd_impl, out_shardings, arg_shardings)``. ``bwd_impl`` returns
            ``(dq, dk, dv, dbias, dsoftmax_offset)``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Ensure segment_pos gets same sharding as ID.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        # dq, dk, dv, dbias, dsoftmax_offset sharding = q, k, v, bias, softmax_offset sharding
        out_shardings = tuple(arg.sharding for arg in arg_infos[:5])

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def bwd_impl(
            q,
            k,
            v,
            bias,
            _softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            """Per-shard striped ring backward over all cp_size KV blocks.

            Raises:
                ValueError: if segment ids are not provided (empty arrays).
            """
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            _not_used = jnp.zeros(0, dtype=output.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            if not config.qkv_layout.is_qkvpacked():
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # We need cp_rank to be a host value for adjust_cp_striped_window_size()
            cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
            # Each rank sends to its right neighbour on every step.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dkv = jnp.zeros_like(kv)
            dbias = jnp.zeros_like(bias)

            def scan_kv_block(idx, carry):
                """One ring step: gradients for the held KV block, then rotate KV/dKV/metadata."""
                kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.
                kv_dkv = jnp.stack([kv, dkv])
                kv_dkv = helper.permute_kv(kv_dkv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                def compute(config):
                    # The 5th output (dsoftmax_offset) is discarded; see the note at the end.
                    dq_per_step, dkv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        _softmax_offset,
                        softmax_aux,
                        rng_state,
                        output,
                        doutput,
                        q_seqlen,
                        kv_seqlen,
                        q_seq_offsets,
                        k_seq_offsets,
                        q_segment_ids,
                        kv_segment_ids,
                        q_segment_pos,
                        kv_segment_pos,
                        config=config,
                    )
                    return dq_per_step, dkv_per_step, dbias_per_step

                if config.window_size != (-1, -1):
                    # The block held at step idx originated on rank (cp_rank - idx) mod cp_size.
                    kv_src_rank = (cp_size + cp_rank - idx) % cp_size
                    # Note: all inputs of adjust_cp_striped_window_size should be host values.
                    # NOTE(review): under lax.fori_loop (helper.use_scanloop()) idx is traced;
                    # presumably check_supported() excludes that combination — confirm.
                    cp_striped_window_size = adjust_cp_striped_window_size(
                        cp_rank, kv_src_rank, cp_size, config.window_size
                    )
                    current_config = replace(
                        subblock_config, cp_striped_window_size=cp_striped_window_size
                    )
                else:
                    current_config = subblock_config
                dq_per_step, dkv_per_step, dbias_per_step = compute(current_config)

                # dkv was rotated together with kv, so the per-step gradient is added to the
                # partial sum that belongs to the block just computed.
                kv_next, dkv = jnp.unstack(kv_dkv)
                dq += dq_per_step
                dkv += dkv_per_step
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step

                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, dq, dkv, dbias)

            carry = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for idx in range(cp_size):
                    carry = scan_kv_block(idx, carry)
            (_, _, _, dq, dkv, dbias) = carry

            # Final permute to put gradients back to their final resting place.
            dkv = helper.permute_kv(dkv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dkv)
            # Return a zero dsoftmax_offset for arity matching (ring attention doesn't use it).
            # NOTE(review): the per-step dsoftmax_offset is discarded above even though
            # softmax_offset is fed to every step; presumably unsupported softmax types are
            # rejected by helper.check_supported() — confirm.
            dummy_dsoftmax_offset = jnp.zeros_like(_softmax_offset)
            return dq, dk, dv, global_dbias, dummy_dsoftmax_offset

        return mesh, bwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnStripedBwdPrimitive)


2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
def _maybe_context_parallel_axis(cp_axis: str):
    if not cp_axis:
        gmr = global_mesh_resource()
        if gmr is not None:
            cp_axis = gmr.cp_resource
        else:
            cp_axis = ""
    return cp_axis


2678
2679
2680
def fused_attn_fwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
2681
    softmax_offset: Optional[jnp.ndarray],
2682
    sequence_descriptor: SequenceDescriptor,
2683
    seed: Optional[jnp.ndarray],
Reese Wang's avatar
Reese Wang committed
2684
2685
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
2686
    softmax_type: AttnSoftmaxType,
Reese Wang's avatar
Reese Wang committed
2687
    qkv_layout: QKVLayout,
2688
2689
2690
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
2691
    max_segments_per_seq: int,
2692
    window_size: Optional[Tuple[int, int]] = None,
2693
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
2694
2695
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
2696
) -> jnp.ndarray:
2697
    """
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
    Perform the forward pass of with cuDNN fused attention implementations.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
2711
        softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
2712
2713
2714
2715
2716
2717
2718
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
Reese Wang's avatar
Reese Wang committed
2719
2720
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
2721
        softmax_type (AttnSoftmaxType): Type of softmax.
Reese Wang's avatar
Reese Wang committed
2722
        qkv_layout (QKVLayout): Layout of the QKV tensors.
2723
2724
2725
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
2726
2727
2728
2729
2730
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
2731
2732
2733
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
2734
2735
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
2736
    """
2737
2738
2739
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
2740

Reese Wang's avatar
Reese Wang committed
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
2758
        assert bias is None
2759
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
2760

2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
    if softmax_offset is None:
        assert (
            softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX
        ), f"Softmax type {softmax_type} is not supported when softmax_offset is None"
        if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            num_heads = qkv[0].shape[-2]
            # Create tensor [1, h, 1, 1] filled with zeros (logit value = 0)
            # This adds exp(0 - x_max) = exp(-x_max) to the denominator,
            # which contributes exactly 1 after normalization, giving: exp(x_i) / (sum(exp(x_j)) + 1)
            softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
            # Shard by heads dimension
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )
        else:
            assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX
            softmax_offset = jnp.zeros(0, dtype=jnp.float32)
    else:
        assert softmax_offset.dtype == jnp.float32
        # Shard by heads dimension if not VANILLA_SOFTMAX
        if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )

2786
    fused_config = _FusedAttnConfig(
2787
2788
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
2789
        qkv_layout=qkv_layout,
2790
        softmax_type=softmax_type,
2791
2792
2793
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
2794
        max_segments_per_seq=max_segments_per_seq,
2795
        window_size=(-1, -1) if window_size is None else window_size,
2796
2797
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
2798
        cp_striped_window_size=None,
2799
2800
    )

2801
    primitive = None
2802
2803
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
2804
            primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
2805
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
2806
2807
2808
2809
2810
            # We must use stripe attention for THD-RING
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnFwdPrimitive.outer_primitive
2811

2812
    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
2813
    output, softmax_aux, rng_state = primitive.bind(
2814
2815
        *qkv_for_primitive,
        bias,
2816
        softmax_offset,
2817
        seed,
2818
        *seq_desc_flatten,
2819
        config=fused_config,
2820
    )
2821
2822
    rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
    return (output, softmax_aux, rng_state)
2823
2824


2825
2826
2827
def fused_attn_bwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
2828
    softmax_offset: Optional[jnp.ndarray],
2829
2830
2831
2832
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
2833
    sequence_descriptor: SequenceDescriptor,
Reese Wang's avatar
Reese Wang committed
2834
2835
2836
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
2837
    softmax_type: AttnSoftmaxType,
2838
2839
2840
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
2841
    max_segments_per_seq: int,
2842
    window_size: Optional[Tuple[int, int]] = None,
2843
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
2844
2845
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
2846
):
2847
    """
2848
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
    Perform the backward pass of the cuDNN fused attention implementations.

    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
        used in the forward pass. It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
2859
        softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
2860
2861
2862
2863
2864
2865
2866
2867
2868
2869
        softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
        rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
        output (jnp.ndarray): The output tensor from the forward pass.
        doutput (jnp.ndarray): The gradient with respect to the output.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
Reese Wang's avatar
Reese Wang committed
2870
2871
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
2872
        softmax_type (AttnSoftmaxType): Type of softmax.
Reese Wang's avatar
Reese Wang committed
2873
        qkv_layout (QKVLayout): Layout of the QKV tensors.
2874
2875
2876
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
2877
2878
2879
2880
2881
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size .
2882
2883
2884
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
2885
2886
2887
2888
2889
    Returns:
        Tuple[jnp.ndarray, ...], jnp.ndarray:
        - The first tuple contains the gradients with respect to the input `qkv` tensors in the
          same format as the input `qkv`.
        - The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
2890
    """
2891
2892
2893
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)

Reese Wang's avatar
Reese Wang committed
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
2911
        assert bias is None
2912
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
2913

2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
    if softmax_offset is None:
        assert softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX, f"Unknown {softmax_type=}"
        if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            num_heads = qkv[0].shape[-2]
            # Create tensor [1, h, 1, 1] filled with zeros
            softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
            # Shard by heads dimension
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )
        elif softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = jnp.zeros(0, dtype=jnp.float32)
        else:
            raise NotImplementedError(f"Unknown {softmax_type=}")
    else:
        softmax_offset = softmax_offset.astype(jnp.float32)
        # Shard by heads dimension if not VANILLA_SOFTMAX
        if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )

2936
2937
2938
2939
    # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
    # sm100+
    compute_capabilities = get_all_device_compute_capability()
    if any(x >= 100 for x in compute_capabilities):
2940
2941
        assert not (
            attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
2942
        ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
2943

2944
2945
2946
2947
    fused_config = _FusedAttnConfig(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
2948
        softmax_type=softmax_type,
2949
2950
2951
2952
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
2953
        window_size=(-1, -1) if window_size is None else window_size,
2954
2955
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
2956
        cp_striped_window_size=None,
2957
2958
    )

2959
    primitive = None
2960
2961
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
2962
            primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
2963
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
2964
2965
2966
2967
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnBwdPrimitive.outer_primitive
2968
2969

    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
2970
    *qkv_grads, bias_grad, softmax_offset_grad = primitive.bind(
2971
        *qkv_for_primitive,
2972
        bias,
2973
        softmax_offset,
2974
2975
2976
2977
        softmax_aux,
        rng_state,
        output,
        doutput,
2978
        *seq_desc_flatten,
2979
        config=fused_config,
2980
    )
2981
    return tuple(qkv_grads[: len(qkv)]), bias_grad, softmax_offset_grad