attention.py 107 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
import operator
6
import os
7
import warnings
8
9
10
from dataclasses import dataclass, replace
from functools import partial, reduce
from typing import Optional, Tuple
11

12
import jax
13
import jax.numpy as jnp
14
from jax import dtypes, lax, ffi
15
from jax.sharding import PartitionSpec, NamedSharding
16
from jax.experimental.custom_partitioning import SdyShardingRule
17
18
19

import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
Reese Wang's avatar
Reese Wang committed
20
21
22
23
24
25
26
27
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    CPStrategy,
    SequenceDescriptor,
)
28

29
30
31
32
33
from .base import BasePrimitive, register_primitive
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
34
    get_padded_spec,
35
    get_cudnn_version,
36
    get_all_device_compute_capability,
37
38
)
from ..sharding import (
39
40
    global_mesh_resource,
    lax_paral_op,
41
    all_reduce_sum_along_dp_fsdp,
42
43
    get_mesh_axis_size,
    get_mesh_axis_rank,
44
    get_mesh_axis_rank_host,
45
46
    get_all_mesh_axes,
    num_of_devices,
47
    with_sharding_constraint,
48
49
50
)


51
52
53
54
55
# Names exported by `from <module> import *`. The primitive classes and the private
# config/checker types defined in this module are intentionally not listed.
# (fused_attn_fwd / fused_attn_bwd are presumably defined further down the file.)
__all__ = [
    "FusedAttnHelper",
    "fused_attn_fwd",
    "fused_attn_bwd",
]
56
57


58
59
60
61
62
63
64
65
66
67
68
@partial(
    jax.tree_util.register_dataclass,
    data_fields=[],
    meta_fields=[
        "attn_bias_type",
        "attn_mask_type",
        "qkv_layout",
        "scaling_factor",
        "dropout_probability",
        "is_training",
        "max_segments_per_seq",
        "window_size",
        "context_parallel_load_balanced",
        "cp_axis",
        "cp_striped_window_size",
    ],
)
@dataclass(frozen=True)
class _FusedAttnConfig:
    """
    Passes static configuration properties of fused attention.

    Every field is registered as a pytree *meta* field (there are no data fields), so an
    instance is carried as static configuration in the treedef rather than as traced data
    when it flows through JAX transformations. The dataclass is frozen, so instances are
    immutable and hashable (provided the field values themselves are hashable).
    """

    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    qkv_layout: QKVLayout
    # Softmax scaling factor, forwarded to the kernel (as float) and to workspace sizing.
    scaling_factor: float
    dropout_probability: float
    # Forwarded to backend selection, workspace sizing and the kernel itself.
    is_training: bool
    # Forwarded to the backend; on cuDNN < 9.6 it also sizes the last dim of softmax_aux.
    max_segments_per_seq: int
    # (left, right) sliding-window sizes, passed on as window_size_left / window_size_right.
    window_size: Tuple[int, int]
    # Context-parallel settings; not read by the forward primitive itself.
    context_parallel_load_balanced: bool
    cp_axis: str
    # Only for CP + Ring + THD + SWA. When not None, the forward lowering uses it instead of
    # `window_size`; None means "use `window_size`".
    cp_striped_window_size: Optional[Tuple[int, int]]
92
93


94
95
96
97
98
99
@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend.

    Bundles the problem description (dtypes, layout, bias/mask types, dropout, head counts,
    max sequence lengths, head dims and window size) that is needed to ask the C++
    extension which fused-attention kernel backend applies, plus static helpers for
    inspecting q/k/v abstract values.
    """

    is_training: bool
    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: QKVLayout
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
    q_max_seqlen: int
    kv_max_seqlen: int
    head_dim_qk: int
    head_dim_v: int
    # (left, right) sliding-window sizes
    window_size: Tuple[int, int]

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend.

        Returns:
            The ``NVTE_Fused_Attn_Backend`` chosen by the C++ extension for this
            configuration; ``NVTE_No_Backend`` means no fused kernel is available.
        """
        return transformer_engine_jax.get_fused_attn_backend(
            self.is_training,
            jax_dtype_to_te_dtype(self.q_dtype),
            jax_dtype_to_te_dtype(self.kv_dtype),
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_probability,
            self.q_num_heads,
            self.kv_num_heads,
            self.q_max_seqlen,
            self.kv_max_seqlen,
            self.head_dim_qk,
            self.head_dim_v,
            self.window_size[0],
            self.window_size[1],
        )

    @staticmethod
    def is_non_deterministic_allowed():
        """Check if non-deterministic kernels are allowed.

        Reads the ``NVTE_ALLOW_NONDETERMINISTIC_ALGO`` environment variable as an integer;
        when unset it defaults to "1" (allowed).
        """
        return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

    @staticmethod
    def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
        """Parse qkv aval.

        Args:
            q_aval, k_aval, v_aval: abstract values (anything with ``.shape``/``.dtype``).
                Which of them carry real data depends on ``qkv_layout``.
            qkv_layout: the ``QKVLayout`` describing how q/k/v are packed.

        Returns:
            ``(q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
            q_head_dim, v_head_dim)``, where ``q_batch_shape`` is the list of leading dims.

        Raises:
            NotImplementedError: for SBHD-format layouts.
            ValueError: for a layout that is neither qkv-packed, kv-packed nor separate.
        """
        if qkv_layout.get_qkv_format() == QKVFormat.SBHD:
            raise NotImplementedError(f"SBHD qkv format is not supported, got {qkv_layout=}")
        if qkv_layout.is_qkvpacked():
            # q_aval = (...batch, seqlen, 3, heads, head_dim); k/v share everything with q
            *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
            kv_max_seqlen = q_max_seqlen
            num_gqa_groups = attn_heads
            v_head_dim = q_head_dim
            assert nqkv == 3, f"Expected packed qkv dimension of 3, but got {nqkv}"
        elif qkv_layout.is_kvpacked():
            # q_aval = (...batch, q_seqlen, heads, head_dim)
            # k_aval = (...batch, kv_seqlen, 2, num_gqa_groups, head_dim)
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape
            assert (
                q_batch_shape == kv_batch_shape
            ), f"Mismatched q_batch_shape: {q_batch_shape} and kv_batch_shape: {kv_batch_shape}"
            assert (
                q_head_dim == v_head_dim
            ), f"Mismatched q_head_dim: {q_head_dim} and v_head_dim: {v_head_dim}"
            assert nkv == 2, f"Expected packed kv dimension of 2, but got {nkv}"
        elif qkv_layout.is_separate():
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape
            *v_batch_shape, v_max_seqlen, v_num_gqa_groups, v_head_dim = v_aval.shape
            assert (
                q_head_dim == k_head_dim
            ), f"Mismatched q_head_dim: {q_head_dim} and k_head_dim: {k_head_dim}"
            assert (
                k_max_seqlen == v_max_seqlen
            ), f"Mismatched k_max_seqlen: {k_max_seqlen} and v_max_seqlen: {v_max_seqlen}"
            kv_max_seqlen = k_max_seqlen
            assert q_batch_shape == k_batch_shape == v_batch_shape, (
                f"Mismatched qkv batch size for q_batch_shape: {q_batch_shape}, k_batch_shape:"
                f" {k_batch_shape} and v_batch_shape: {v_batch_shape}"
            )
            assert k_num_gqa_groups == v_num_gqa_groups, (
                f"Mismatched k_num_gqa_groups: {k_num_gqa_groups} and v_num_gqa_groups:"
                f" {v_num_gqa_groups}"
            )
            num_gqa_groups = k_num_gqa_groups
        else:
            raise ValueError(f"Unexpected {qkv_layout=}")
        assert q_aval.dtype == k_aval.dtype == v_aval.dtype, (
            f"Mismatched data types for q_aval: {q_aval.dtype}, k_aval: {k_aval.dtype}, v_aval:"
            f" {v_aval.dtype}"
        )
        return (
            q_batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        )
197
198
199
200
201
202
203
204
205
206
207


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.
    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    However, JAX doesn't enable 64 bits by default,
    so we have to emulate seed as two 32 bits array.
    The offset calculation is maintained in the backend.
    """

    # JAX-side dtype used to carry the emulated 64-bit values (seed and rng state)
    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64
    seed_size: int = 2
    # (seed, offset) with internal dtype int64
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.

        Args:
            seed: seed array, or None when dropout is not active.
            dropout_probability: dropout probability of the attention call.
            is_training: whether the call is a training call.

        Returns:
            The seed as an array of ``rng_state_dtype`` with at least ``seed_size`` elements.
            For ``seed=None`` this is an all-zero dummy of ``seed_size * num_of_devices()``
            elements.

        Raises:
            AssertionError: if ``seed`` is None while dropout is enabled, or if the seed has
                fewer than ``seed_size`` elements.
        """
        # Jax can't bind None, create a dummy tensor for None
        if seed is None:
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            seed = jnp.zeros(self.seed_size, dtype=self.rng_state_dtype)
            seed = jnp.repeat(seed, num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                "Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning."
            )
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert (
            seed.size >= self.seed_size
        ), f"Expected a seed with at least {self.seed_size} elements, but got {seed.size}"

        return seed


def generate_cu_seqlen(actual_seqlen):
    """
    Generating cumsum seqlen for a batch.

    Negative entries (e.g. -1 padding markers) are treated as length 0. For 1-D input the
    result is the running sum with a leading 0, so it has one more element than the input:
    ``[3, -1, 2] -> [0, 3, 3, 5]``.
    """
    # Clamp invalid (negative) lengths to zero before accumulating.
    actual_seqlen = jnp.maximum(actual_seqlen, 0)
    return jnp.cumulative_sum(actual_seqlen, include_initial=True)


class FusedAttnFwdPrimitive(BasePrimitive):
    """
    Fused Attention Forward Primitive

    Operands, in order: q, k, v, bias, seed, q/kv seqlens (cumulative seqlens once lowered),
    q/k sequence offsets, q/kv segment ids and q/kv segment positions. Results are
    (output, softmax_aux, rng_state); the inner primitive additionally returns a workspace
    buffer.
    """

    name = "te_fused_attn_forward_ffi"
    multiple_results = True
    # `config` follows the 13 array operands, i.e. it is positional index 13 of `impl`.
    impl_static_args = (13,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def _bias_batch_and_heads(bias_aval, attn_bias_type):
        """Return ``(bias_batch, bias_heads)`` as expected by the backend.

        Both are 0 when there is no bias. Otherwise the bias is unpacked as
        ``(*bias_batch_shape, bias_heads, _, _)`` and all leading dims are folded into a
        single batch count (1 when there are no leading dims).
        """
        if attn_bias_type == AttnBiasType.NO_BIAS:
            return 0, 0
        *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
        return reduce(operator.mul, bias_batch_shape, 1), bias_heads

    @staticmethod
    def _uses_packed_softmax(config):
        """Whether softmax_aux is laid out as (b, s, h, 1).

        This holds for THD layouts on cuDNN 9.6+. Otherwise softmax_aux is (b, h, s, 1) or
        (b, h, s, max_segments). Shared by both sharding rules so they stay in sync.
        """
        return get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        seed_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd abstract

        Returns the avals of (output, softmax_aux, rng_state, workspace).
        """
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert (
            q_dtype == k_dtype == v_dtype == bias_dtype
        ), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}"
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
            f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
            f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
        )

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim)
        out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)

        # backend determines the softmax buffer shape/dtype
        backend = FusedAttnHelper(
            config.is_training,
            q_dtype,
            k_dtype,
            config.qkv_layout,
            config.attn_bias_type,
            config.attn_mask_type,
            config.dropout_probability,
            attn_heads,
            num_gqa_groups,
            q_max_seqlen,
            kv_max_seqlen,
            q_head_dim,
            v_head_dim,
            config.window_size,
        ).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
            softmax_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            # cuDNN 9.6 reduces the required softmax shape
            if get_cudnn_version() >= (9, 6, 0):
                if config.qkv_layout.is_thd():
                    softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
                else:
                    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
            else:
                softmax_shape = (
                    *batch_shape,
                    attn_heads,
                    q_max_seqlen,
                    config.max_segments_per_seq,
                )
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f"Unsupported {backend=}")
        softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)

        # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
        # 32-bit unsigned int to get the buffer size we need in the C++ kernel
        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

        bias_batch, bias_heads = FusedAttnFwdPrimitive._bias_batch_and_heads(
            bias_aval, config.attn_bias_type
        )

        # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
        # prepare for the active fused-attn backend
        input_batch = reduce(operator.mul, batch_shape, 1)
        wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            q_head_dim,
            v_head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
        )
        wkspace_aval = q_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract

        Same as `abstract`, minus the workspace aval.
        """
        out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract(
            *args, **kwargs
        )
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        seed,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd lowering rules

        Lowers to the `te_fused_attn_forward_ffi` custom call, passing problem sizes and
        the static config as FFI attributes.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            q_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        input_batch = reduce(operator.mul, batch_shape, 1)
        bias_batch, bias_heads = FusedAttnFwdPrimitive._bias_batch_and_heads(
            bias_aval, config.attn_bias_type
        )

        # The striped CP window (CP + Ring + THD + SWA) overrides the regular window.
        window_size = (
            config.cp_striped_window_size
            if config.cp_striped_window_size is not None
            else config.window_size
        )

        return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
            ctx,
            q,
            k,
            v,
            bias,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,  # ffi_lowering needs number of parameters meets primitive.lowering
            input_batch=input_batch,
            bias_batch=bias_batch,
            q_max_seqlen=q_max_seqlen,
            kv_max_seqlen=kv_max_seqlen,
            attn_heads=attn_heads,
            num_gqa_groups=num_gqa_groups,
            bias_heads=bias_heads,
            qk_head_dim=q_head_dim,
            v_head_dim=v_head_dim,
            max_segments_per_seq=config.max_segments_per_seq,
            scaling_factor=float(config.scaling_factor),
            dropout_probability=float(config.dropout_probability),
            bias_type=int(config.attn_bias_type.value),
            mask_type=int(config.attn_mask_type.value),
            qkv_layout=int(config.qkv_layout.value),
            is_training=config.is_training,
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size[0],
            window_size_right=window_size[1],
        )

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        seed,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd implementation.

        Resolves the sequence descriptor into seqlens/offsets, compacts them into the form
        the kernel expects (THD layouts only), builds cumulative seqlens and binds the inner
        primitive. Returns (output, softmax_aux, rng_state).
        """
        assert FusedAttnFwdPrimitive.inner_primitive is not None

        sequence_descriptor = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )

        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
                config.qkv_layout,
                config.window_size,
                config.max_segments_per_seq,
            )
        )

        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
                # Move the elements of `x` where `condition` holds to the front of the
                # flattened array (keeping their order), fill the rest with `fill_value`,
                # and restore the original shape.
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
                # Shift each row's non-negative offsets by row * max_seqlen so they index the
                # flattened (batch * max_seqlen) token dimension; negatives stay untouched.
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
            kv_batch = q_batch = batch[0]

            # Gather valid q_seqlen, which is greater than 0
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            fill_value = 0 if get_cudnn_version() >= (9, 3, 0) else -1

            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid q_seq_offsets, which is greater and equal to 0
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # And set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _fix_len_take(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _fix_len_take(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """Batching rule: output and softmax_aux follow q's batch dim, rng_state follows seed's."""
        check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, _, _, _, seed_bdim, *_ = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
        return (
            FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """Derive output shardings from q's sharding. Keep in sync with `shardy_sharding_rule`."""
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])

        if config.qkv_layout.is_qkvpacked():
            # q_spec = (...batch, q_seqlen, 3, head, hidden)
            batch_spec, seqlen_spec, head_spec = q_spec[:-4], q_spec[-4], q_spec[-2]
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
        elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
            # q_spec = (...batch, q_seqlen, head, hidden)
            # kv-packed: k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
            # separate:  k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
            batch_spec, seqlen_spec, head_spec = q_spec[:-3], q_spec[-3], q_spec[-2]
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")

        # softmax_aux is (b, s, h, 1) when packed, otherwise (b, h, s, 1 | max_segments)
        if FusedAttnFwdPrimitive._uses_packed_softmax(config):
            softmax_aux_spec = PartitionSpec(*batch_spec, seqlen_spec, head_spec, None)
        else:
            softmax_aux_spec = PartitionSpec(*batch_spec, head_spec, seqlen_spec, None)
        softmax_aux_sharding = NamedSharding(mesh, softmax_aux_spec)

        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition rule: seed/rng_state are split over all mesh axes, the rest is kept."""
        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # operand 4 is the seed
        arg_shardings[4] = seed_sharding
        # segment positions (last two operands) follow their segment ids' shardings
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(FusedAttnFwdPrimitive.impl, config=config)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(config, mesh, value_types, result_types):
        """Shardy rule equivalent of `infer_sharding_from_operands`."""
        del mesh, result_types

        # Keep in sync with `infer_sharding_from_operands`.
        # We only need the first input. Fill up the rest with placeholders.
        input_spec = [(f"…{x}",) for x in range(len(value_types))]
        # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
        # instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
        rng_sharding = (f"…{len(value_types)}",)

        if config.qkv_layout.is_qkvpacked():
            input_spec[0] = ("…0", "seqlen", "three", "head", "hidden")
        elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
            input_spec[0] = ("…0", "seqlen", "head", "hidden")
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")

        out_sharding = ("…0", "seqlen", "head", "hidden")
        if FusedAttnFwdPrimitive._uses_packed_softmax(config):
            softmax_aux_sharding = ("…0", "seqlen", "head", "i")
        else:
            softmax_aux_sharding = ("…0", "head", "seqlen", "i")

        return SdyShardingRule(
            tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding)
        )

702
703
704
705
706
707
708
709

# Register the forward primitive through the shared `register_primitive` helper.
# NOTE(review): presumably this is what populates `inner_primitive` / `outer_primitive`,
# which `FusedAttnFwdPrimitive.impl` / `.batcher` assert are set — confirm in `.base`.
register_primitive(FusedAttnFwdPrimitive)


class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive

    Produces (dq, dk, dv, dbias). Array operands, in order: q, k, v, bias, softmax_aux,
    rng_state, output, doutput, q/kv sequence lengths (or cu_seqlens), q/k sequence
    offsets, q/kv segment ids and q/kv segment positions. `config` follows them.
    """

    name = "te_fused_attn_backward_ffi"
    multiple_results = True
    # Index of `config` in `impl`'s positional signature (it follows the 16 array operands).
    impl_static_args = (16,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_aux_aval,
        rng_state_aval,
        output_aval,
        doutput_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd abstract

        Returns abstract values for (dq, dk, dv, dbias, workspace). The gradients mirror the
        shapes and dtypes of q, k, v and bias; the workspace shape and dtype are queried from
        `transformer_engine_jax.get_fused_attn_bwd_workspace_sizes`.
        """
        del softmax_aux_aval, rng_state_aval, output_aval

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        # q, k, v, bias and doutput must all share one dtype.
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            qk_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            # Leading bias dims fold into one batch size; the trailing two dims are unused here.
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        deterministic = not FusedAttnHelper.is_non_deterministic_allowed()

        input_batch = reduce(operator.mul, batch_shape)
        # NOTE(review): the workspace is sized with `config.window_size`, whereas `lowering`
        # runs the kernel with `config.cp_striped_window_size` when that is set. Confirm the
        # workspace size does not depend on the window size.
        wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            qk_head_dim,
            v_head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            deterministic,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
        )

        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(
            shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
        )

        return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract

        Same as `abstract`, but the workspace aval is dropped from the results.
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        return dq_aval, dk_aval, dv_aval, dbias_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        q_segment_ids,
        kv_segment_ids,
        q_segment_pos,
        kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd lowering rules

        Lowers to the `te_fused_attn_backward_ffi` FFI target. The problem sizes are derived
        from the input avals and passed as FFI attributes.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        (
            batch_shape,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            qk_head_dim,
            v_head_dim,
        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

        input_batch = reduce(operator.mul, batch_shape)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        # A striped context-parallel window, when present, overrides the configured window.
        if config.cp_striped_window_size is not None:
            window_size_left = config.cp_striped_window_size[0]
            window_size_right = config.cp_striped_window_size[1]
        else:
            window_size_left = config.window_size[0]
            window_size_right = config.window_size[1]

        return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
            ctx,
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,  # ffi_lowering needs number of parameters meets primitive.lowering
            input_batch=input_batch,
            bias_batch=bias_batch,
            q_max_seqlen=q_max_seqlen,
            kv_max_seqlen=kv_max_seqlen,
            attn_heads=attn_heads,
            num_gqa_groups=num_gqa_groups,
            bias_heads=bias_heads,
            qk_head_dim=qk_head_dim,
            v_head_dim=v_head_dim,
            max_segments_per_seq=config.max_segments_per_seq,
            scaling_factor=float(config.scaling_factor),
            dropout_probability=float(config.dropout_probability),
            bias_type=int(config.attn_bias_type.value),
            mask_type=int(config.attn_mask_type.value),
            qkv_layout=int(config.qkv_layout.value),
            is_training=config.is_training,
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
        )

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config,
    ):
        """Run the backward pass via the inner primitive.

        Sequence lengths and offsets are first derived from a `SequenceDescriptor`. For THD
        layouts the valid entries are compacted to the front and the offsets are flattened
        across the batch. Cumulative sequence lengths are then generated and the inner
        primitive is bound.

        Returns:
            Tuple (dq, dk, dv, dbias). The workspace output is discarded.
        """
        assert FusedAttnBwdPrimitive.inner_primitive is not None

        sequence_descriptor = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )

        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
                config.qkv_layout,
                config.window_size,
                config.max_segments_per_seq,
            )
        )

        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
                # Move the entries of x where `condition` holds to the front (in order), fill
                # the remainder with `fill_value`, and restore x's original shape.
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                # TODO(rewang): try indices_are_sorted
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
                # Shift each row's non-negative offsets by row_index * max_seqlen.
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1
            kv_batch = q_batch = batch[0]

            # Gather valid q_seqlen, which is greater than 0
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1
            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid q_seq_offsets, which is greater and equal to 0
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # And set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _fix_len_take(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _fix_len_take(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return dq, dk, dv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """Batching rule: bind the outer primitive on the batched args.

        dq/dk/dv reuse the q/k/v batch dims; dbias is reported with q's batch dim.
        """
        check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, k_bdim, v_bdim, *_ = batch_dims

        out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
        return (
            FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """Each gradient (dq, dk, dv, dbias) takes the sharding of its matching operand."""
        del config, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition rule: run `impl` per shard and all-reduce dbias over DP/FSDP axes.

        Returns:
            (mesh, sharded_impl, out_shardings, arg_shardings).
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # The segment positions (last two operands) take the sharding of the matching segment
        # ids: kv_segment_pos <- kv_segment_ids, q_segment_pos <- q_segment_ids.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        def sharded_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
                config=config,
            )
            global_dbias = local_dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                # dbias from each DP/FSDP shard is summed into the full gradient.
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            return local_dq, local_dk, local_dv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(config, mesh, value_types, result_types):
        """Shardy rule: every operand and result gets an independent ellipsis placeholder."""
        del config, mesh
        # We only care about the four first arguments.
        # Keep in sync with `infer_sharding_from_operands`.
        input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
        output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
        return SdyShardingRule(input_spec, output_spec)

1112
1113
1114
1115

# Register the backward primitive via the `register_primitive` helper from `.base`
# (presumably this populates `inner_primitive` / `outer_primitive`, which `impl` and
# `batcher` assert are set — confirm in `.base`).
register_primitive(FusedAttnBwdPrimitive)


Reese Wang's avatar
Reese Wang committed
1116
def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
    """Reorders a tensor for load balancing the compute of causal attention.

    The sequence axis ``seq_dim`` is split into ``2 * cp_size`` equal chunks. With
    ``to_contiguous=False`` the chunks are permuted so that the slot for CP rank ``r`` holds
    chunks ``r`` and ``2 * cp_size - 1 - r`` (one early and one late chunk per rank). With
    ``to_contiguous=True`` the inverse permutation is applied.

    Args:
        tensor: array to reorder.
        cp_size: context-parallel size; must be 1 or even.
        seq_dim: index of the sequence axis. Negative indices are not normalized, so pass
            a non-negative axis.
        to_contiguous: if True, map the load-balanced order back to contiguous order.

    Returns:
        An array with ``tensor``'s shape (``tensor`` itself when ``cp_size == 1``).

    Raises:
        ValueError: if ``cp_size`` is odd (other than 1), or the sequence length is not a
            multiple of ``2 * cp_size``.
    """
    if cp_size == 1:
        return tensor

    if cp_size % 2 != 0:
        raise ValueError(f"{cp_size=} must be a multiple of 2.")

    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
    if tensor.shape[seq_dim] % (cp_size * 2) != 0:
        raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")

    # Split the sequence axis into 2*cp_size chunks of length S/(2*cp_size):
    # [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]
    # [S, B, H, D] -> [2*cp_size, S/(2*cp_size), B, H, D]
    ori_tensor_shape = tensor.shape
    tensor = tensor.reshape(
        (
            *ori_tensor_shape[:seq_dim],
            2 * cp_size,
            ori_tensor_shape[seq_dim] // (2 * cp_size),
            *ori_tensor_shape[seq_dim + 1 :],
        )
    )

    parts = []
    if not to_contiguous:
        # Contiguous -> load-balanced: rank r takes chunks (r, 2*cp_size - 1 - r).
        for cp_rank in range(cp_size):
            # [B, S, H, D]: [B, 2*cp_size, S/(2*cp_size), H, D] -> [B, 2, S/(2*cp_size), H, D]
            # [S, B, H, D]: [2*cp_size, S/(2*cp_size), B, H, D] -> [2, S/(2*cp_size), B, H, D]
            index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)])
            parts.append(jnp.take(tensor, index, axis=seq_dim))
    else:
        # Load-balanced -> contiguous. Chunk j < cp_size sits at slot 2*j (the even slots,
        # taken in pairs); chunk j >= cp_size sits at an odd slot, walked downwards from
        # 2*cp_size - 1.
        for cp_rank in range(cp_size // 2):
            # [B, S, H, D]: [B, 2*cp_size, S/(2*cp_size), H, D] -> [B, 2, S/(2*cp_size), H, D]
            # [S, B, H, D]: [2*cp_size, S/(2*cp_size), B, H, D] -> [2, S/(2*cp_size), B, H, D]
            base = 4 * cp_rank
            index = jnp.array([base, base + 2])
            parts.append(jnp.take(tensor, index, axis=seq_dim))
        for cp_rank in range(cp_size // 2):
            # [B, S, H, D]: [B, 2*cp_size, S/(2*cp_size), H, D] -> [B, 2, S/(2*cp_size), H, D]
            # [S, B, H, D]: [2*cp_size, S/(2*cp_size), B, H, D] -> [2, S/(2*cp_size), B, H, D]
            base = 2 * cp_size - 1 - 4 * cp_rank
            index = jnp.array([base, base - 2])
            parts.append(jnp.take(tensor, index, axis=seq_dim))

    # [B, S, H, D]: [B, cp_size, 2, S/(2*cp_size), H, D]
    # [S, B, H, D]: [cp_size, 2, S/(2*cp_size), B, H, D]
    combined = jnp.stack(parts, axis=seq_dim)

    return combined.reshape(ori_tensor_shape)


Reese Wang's avatar
Reese Wang committed
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
    """Reorders a tensor for load balancing with striped pattern

    Forward (``is_inverse=False``): the sequence axis is viewed as ``(S // cp_size, cp_size)``
    and the two factors are swapped. Output chunk ``r`` (of length ``S // cp_size``) therefore
    collects tokens ``r, r + cp_size, r + 2 * cp_size, ...``. The inverse views the axis as
    ``(cp_size, S // cp_size)`` and swaps back.

    Raises:
        ValueError: if the sequence length is not a multiple of ``cp_size``.
    """
    origin_shape = tensor.shape
    seq_len = origin_shape[seq_dim]
    if seq_len % cp_size:
        raise ValueError(
            "Expected origin_shape[seq_dim] is multiple of cp_size but got"
            f" {origin_shape[seq_dim]=} and {cp_size=}"
        )

    stride = seq_len // cp_size
    split_dims = (cp_size, stride) if is_inverse else (stride, cp_size)
    chunked = tensor.reshape(
        (*origin_shape[:seq_dim], *split_dims, *origin_shape[seq_dim + 1 :])
    )
    return jnp.swapaxes(chunked, seq_dim, seq_dim + 1).reshape(origin_shape)


1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
    """Helper class to assist with running the all-gather strategy for CP attention."""

    # Device mesh that contains the context-parallel axis `config.cp_axis`.
    mesh: jax.sharding.Mesh
    # Attention configuration of the (unsplit) fused attention call.
    config: _FusedAttnConfig

    def check_supported(self):
        """Checks if the context parallel implementation is supported by the given arguments.

        Raises:
            ValueError: for a layout other than BSHD_BS2HD / BSHD_BSHD_BSHD, any attention
                bias, a mask other than NO_MASK / CAUSAL_MASK, ``max_segments_per_seq != 1``,
                or a non-zero dropout probability.
        """
        header = "Context parallel fused attention"

        allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
        if self.config.qkv_layout not in allowed_layouts:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
            )

        if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

        allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
        if self.config.attn_mask_type not in allowed_masks:
            raise ValueError(
                f"{header} only supports masking types: "
                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
            )

        if self.config.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got:"
                f" {self.config.max_segments_per_seq}"
            )

        if self.config.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

    def get_adjusted_mask(self):
        """Converts the mask for context parallelism.

        CAUSAL_MASK becomes CAUSAL_BOTTOM_RIGHT_MASK; any other mask is returned unchanged.
        Presumably this is because each step's KV slice ends at the end of the local q chunk
        (see `kv_seqlens_for_rank`). Confirm before relying on that reasoning.
        """
        if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
            return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
        return self.config.attn_mask_type

    def get_step_config(self) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention.

        Copies `self.config`, with two changes: the mask is replaced by `get_adjusted_mask()`
        and `cp_striped_window_size` is cleared.
        """
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
            attn_mask_type=self.get_adjusted_mask(),
            qkv_layout=self.config.qkv_layout,
            scaling_factor=self.config.scaling_factor,
            dropout_probability=self.config.dropout_probability,
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
            window_size=self.config.window_size,
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
            cp_striped_window_size=None,
        )

    def all_gather_kv(self, k, v):
        """Performs an all-gather of k and v over context parallel ranks.

        Gathers along axis 1 (tiled). When load balancing is enabled, the gathered tensor is
        then reordered back to contiguous sequence order. For KV-packed layouts only `k` is
        gathered (presumably it carries the packed KV) and `v` is returned unchanged.
        """

        def ag(x):
            x = lax_paral_op(
                x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
            )
            if self.config.context_parallel_load_balanced:
                cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
                x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True)
            return x

        if self.config.qkv_layout.is_kvpacked():
            return ag(k), v
        if self.config.qkv_layout.is_separate():
            return ag(k), ag(v)

        return k, v  # fall through

    def reduce_scatter_dkv(self, dk, dv):
        """Performs a reduce-scatter of dk and dv over context parallel ranks.

        This is the inverse of `all_gather_kv`. When load balancing is enabled, the gradients
        are first reordered into the load-balanced layout. They are then summed and scattered
        along axis 1 (tiled).
        """

        def rs(x):
            if self.config.context_parallel_load_balanced:
                cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
                x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)

            return lax_paral_op(
                x,
                lax.psum_scatter,
                self.config.cp_axis,
                mesh=self.mesh,
                scatter_dimension=1,
                tiled=True,
            )

        if self.config.qkv_layout.is_kvpacked():
            return rs(dk), dv
        if self.config.qkv_layout.is_separate():
            return rs(dk), rs(dv)

        return dk, dv  # fall through

    def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
        """Returns sequence lengths of KV to use for each sub rank of the given cp_rank.

        Example: CP=4, MaxLen = 1024, Unbalanced
           cp_rank 0: [128, 256]
           cp_rank 1: [384, 512]
           cp_rank 2: [640, 768]
           cp_rank 3: [896, 1024]

        Example: CP=4, MaxLen = 1024, Balanced
           cp_rank 0: [128, 1024]
           cp_rank 1: [256, 896]
           cp_rank 2: [384, 768]
           cp_rank 3: [512, 640]
        """
        if self.config.context_parallel_load_balanced:
            kv_seq_this_rank = [
                (cp_rank + 1) * kv_seqlen_per_subrank,
                kv_max_seqlen - cp_rank * kv_seqlen_per_subrank,
            ]
        else:
            kv_seq_this_rank = [
                (cp_rank * 2 + 1) * kv_seqlen_per_subrank,
                (cp_rank * 2 + 2) * kv_seqlen_per_subrank,
            ]
        return kv_seq_this_rank

    def slice_kv(self, k, v, slice_seq_len):
        """Slices k and v tensors to a sequence length of slice_seq_len.

        Keeps the first `slice_seq_len` entries along axis 1. For KV-packed layouts only `k`
        is sliced and `v` is returned unchanged.
        """

        def sliced(x):
            return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)

        if self.config.qkv_layout.is_kvpacked():
            return sliced(k), v
        if self.config.qkv_layout.is_separate():
            return sliced(k), sliced(v)

        return k, v  # fall through

    def pad_kv(self, dk, dv, pad_seq_len):
        """Appends pad_seq_len zeros to dk and dv along the sequence axis (axis 1).

        The resulting sequence length is the input length plus `pad_seq_len`. This is the
        inverse of `slice_kv` when `pad_seq_len` is the number of entries that were sliced
        off. For KV-packed layouts only `dk` (assumed 5-D) is padded and `dv` is returned
        unchanged.
        """

        def pad(x, npad):
            return jnp.pad(x, npad, "constant", constant_values=0.0)

        if self.config.qkv_layout.is_kvpacked():
            npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
            return pad(dk, npad), dv
        if self.config.qkv_layout.is_separate():
            npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
            return pad(dk, npad), pad(dv, npad)

        return dk, dv  # fall through


class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Attention Forward with Context Parallelism Primitive

    This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition rule for the all-gather context-parallel forward pass.

        When the CP axis has size 1, this defers to `FusedAttnFwdPrimitive.partition`.
        Otherwise each rank all-gathers KV, splits its local q into two halves, and runs one
        fused attention step per half against the KV slice returned by
        `helper.kv_seqlens_for_rank`.

        Returns:
            (mesh, impl, out_shardings, arg_shardings).
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        # The seed operand (index 4) and the RNG state result are sharded across all mesh axes.
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # cuDNN does not support right-aligned masking with dynamic sequence length padding.
            # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
            # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
            # meeting the expectation of the SPMD model.
            # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
            # mask/sequence length tensor to avoid this unrolled loop.
            def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                # NOTE(review): this only checks divisibility by cp_size, yet the length is
                # split into 2 * cp_size sub-ranks above. An odd quotient would truncate
                # silently. Confirm whether `% (cp_size * 2)` is the intended check.
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                q_split = jnp.split(q, 2, axis=1)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == AttnMaskType.NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    # NOTE(review): `/` is true division, so these step seqlens are floats when
                    # the inputs are integer arrays. Also, num_kv_chunks grows as the KV slice
                    # shrinks (e.g. CP=4, max 1024, rank 0 step 0: slice 128, factor 8), so
                    # kv_seqlen_for_step does not track the sliced KV length. Only NO_MASK /
                    # CAUSAL_MASK pass `check_supported`. Confirm whether these seqlens are
                    # consumed at all before changing this.
                    q_seqlen_for_step = q_seqlen / (cp_size * 2)
                    num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
                    kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks

                    output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        seed,
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )
                    results.append((output, softmax_aux, rng_state))

                # Stitch the two half-q steps back together: output along the sequence axis
                # (1), softmax aux along axis 2.
                output = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2)
                rng_state = results[1][2]  # Use the final RNG state

                return output, softmax_aux, rng_state

            k_ag, v_ag = helper.all_gather_kv(k, v)

            # One branch per possible CP rank; `lax.switch` picks the branch for this rank.
            functions = [
                partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed)
                for idx in range(cp_size)
            ]

            return lax.switch(cp_rank, functions)

        return mesh, impl, out_shardings, arg_shardings


# Register the all-gather context-parallel forward primitive via the `register_primitive`
# helper from `.base` (its exact effect is defined in `.base`; presumably it creates the
# inner/outer JAX primitives for this subclass — confirm there).
register_primitive(FusedAttnCPWithAllGatherFwdPrimitive)


class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Attention Backward with Context Parallelism Primitive.

    This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks.
    The gradients are subsequently reduce-scattered back to each context parallel rank.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partitions the fused attention backward pass for all-gather context parallelism.

        If the ``config.cp_axis`` mesh axis has size 1, this defers to
        ``FusedAttnBwdPrimitive.partition``. Otherwise it returns
        ``(mesh, impl, out_shardings, arg_shardings)``. Here ``impl``:

        1. all-gathers K/V;
        2. runs the backward pass for this rank's two query halves;
        3. reduce-scatters dK/dV back to the owning ranks.

        Raises:
            AssertionError: if sliding-window attention is requested together with context
                parallelism.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Ensure we can support this configuration with context parallelism.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        del result_infos
        # Each gradient is sharded exactly like the input it is the gradient of.
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # One specialized branch is built per possible CP rank.
            # - `idx` is bound as a Python int via `partial` below.
            # - lax.switch picks the branch for this device's rank at runtime.
            # - The forward all-gather CP primitive uses the same pattern.
            # - NOTE(review): presumably `idx` must be static because the helper's KV slicing
            #   needs static sizes.
            def _cross_attn_bwd(
                idx,
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlen,
                kv_seqlen,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
            ):
                """Backward pass for CP rank `idx` against the all-gathered `k`/`v`.

                The local query and its output, doutput and softmax_aux are processed as two
                halves. The dq halves are concatenated. dk/dv from both halves are summed at
                full (all-gathered) KV length so they can be reduce-scattered.
                """
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                assert (
                    kv_max_seqlen % cp_size == 0
                ), "sequence length must be evenly divisible by cp size"

                q_split = jnp.split(q, 2, axis=1)
                output_split = jnp.split(output, 2, axis=1)
                doutput_split = jnp.split(doutput, 2, axis=1)
                softmax_aux_split = jnp.split(softmax_aux, 2, axis=2)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == AttnMaskType.NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    q_seqlen_for_step = q_seqlen // (cp_size * 2)
                    num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

                    dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        softmax_aux_split[sub_idx],
                        rng_state,
                        output_split[sub_idx],
                        doutput_split[sub_idx],
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )

                    # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
                    if config.attn_mask_type != AttnMaskType.NO_MASK:
                        pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx]
                        dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length)

                    results.append((dq_local, dk_local, dv_local, dbias_local))

                dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                dk_local_pad = results[0][1] + results[1][1]
                dv_local_pad = results[0][2] + results[1][2]
                # dbias of the second half is returned as-is (not summed across halves).
                # NOTE(review): confirm bias is rejected by helper.check_supported for this path.
                return dq_local, dk_local_pad, dv_local_pad, results[1][3]

            k_ag, v_ag = helper.all_gather_kv(k, v)

            functions = [
                partial(
                    _cross_attn_bwd,
                    idx,
                    q,
                    k_ag,
                    v_ag,
                    bias,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_seqlen,
                    kv_seqlen,
                    _q_segment_ids,
                    _kv_segment_ids,
                    _q_segment_pos,
                    _kv_segment_pos,
                )
                for idx in range(cp_size)
            ]

            dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
            # Sum the full-length dK/dV contributions across ranks and keep only this rank's shard.
            dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)

            return dq, dk, dv, dbias

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)


@dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper:
    """Helper class to assist with running the P2P ring strategy for CP attention.

    It bundles the device mesh and the user-facing fused attention config. It provides:

    - validation of the configuration;
    - per-step config construction;
    - K/V stacking and unstacking;
    - ring permutation;
    - the partial-result (output / softmax statistics) correction used by the ring
      forward/backward primitives.
    """

    mesh: jax.sharding.Mesh
    config: _FusedAttnConfig

    @staticmethod
    def use_scanloop():
        """Returns true if the implementation will use a scan loop for iteration.

        Read from the ``NVTE_FUSED_RING_ATTENTION_USE_SCAN`` environment variable, which
        defaults to ``"1"``. The value is parsed with ``int()``, so it must be an integer
        string; any non-zero value enables the loop.
        """
        return bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))

    def check_supported(self):
        """Checks if the context parallel implementation is supported by the given arguments.

        Raises:
            ValueError: if the layout, bias type, mask type, max_segments_per_seq or dropout
                is not supported. Also raised for THD + sliding window when the scan loop is
                enabled under context parallelism.
        """
        header = "Context parallel fused ring attention"

        if self.config.qkv_layout.is_thd():
            allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
        else:
            allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]

        if self.config.qkv_layout not in allowed_layouts:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
            )

        if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

        if self.config.qkv_layout.is_thd():
            allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK]
        else:
            allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]

        if self.config.attn_mask_type not in allowed_masks:
            raise ValueError(
                f"{header} only supports masking types: "
                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
            )

        if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got:"
                f" {self.config.max_segments_per_seq}"
            )

        if self.config.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

        # We want to encourage use of scan loop to minimize unrolling and ensure more
        # predictable scheduling from XLA. The unrolled flavor will be supported but
        # not the preferred implementation.
        if not self.use_scanloop():
            warnings.warn(
                "Scan loop is disabled for fused ring attention. To enable set"
                " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
            )

        # If using scanloop, idx in scan_kv_block() will be a traced device value, but
        # _normalize_window_size_for_cp_striped() requires all parameters to be host values
        is_context_parallel = get_mesh_axis_size(self.config.cp_axis, self.mesh) > 1
        is_thd_layout = self.config.qkv_layout.is_thd()
        is_sliding_window = self.config.window_size[0] != -1
        if is_context_parallel and is_thd_layout and is_sliding_window and self.use_scanloop():
            raise ValueError(
                f"{header} with THD format and sliding window does not support using scan loop"
            )

    def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention.

        The config is copied from ``self.config`` with these changes:

        - the given ``attn_mask_type`` is used;
        - the layout is forced to ``QKVLayout.BSHD_BS2HD``;
        - ``cp_striped_window_size`` is cleared.

        NOTE(review): presumably BSHD_BS2HD is forced because the ring primitives pass K/V
        stacked (see ``stack_kv``) with an empty V placeholder.
        """
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=QKVLayout.BSHD_BS2HD,
            scaling_factor=self.config.scaling_factor,
            dropout_probability=self.config.dropout_probability,
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
            window_size=self.config.window_size,
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
            cp_striped_window_size=None,
        )

    def stack_kv(self, k, v):
        """Stacks k and v tensors if not stacked.

        The result depends on the layout:

        - KV-packed layouts: ``k`` is returned unchanged, since it already holds K and V.
        - Separate layouts: K and V are stacked on axis 2.
        - Any other layout: an empty placeholder array is returned.
        """
        _not_used = jnp.zeros(0, dtype=k.dtype)
        if self.config.qkv_layout.is_kvpacked():
            return k
        if self.config.qkv_layout.is_separate():
            return jnp.stack([k, v], axis=2)
        return _not_used

    def unstack_kv(self, kv):
        """Un-stacks k and v tensors if not stacked.

        This is the inverse of ``stack_kv``. For KV-packed layouts it returns
        ``(kv, empty)``. For separate layouts it unstacks axis 2 into ``(k, v)``.
        For any other layout it returns two empty placeholders.
        """
        _not_used = jnp.zeros(0, dtype=kv.dtype)
        if self.config.qkv_layout.is_kvpacked():
            return kv, _not_used
        if self.config.qkv_layout.is_separate():
            return jnp.unstack(kv, axis=2)
        return _not_used, _not_used  # fall through

    def permute_kv(self, kv, cp_perm):
        """Permutes kv around the ring as described by cp_perm."""
        return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)

    @staticmethod
    def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
        """
        Corrects the output and softmax_aux tensor after each iteration of ring attention.

        With ``w = sigmoid(partial_softmax_aux - softmax_aux)`` this computes:

        - ``new_out = (1 - w) * output + w * partial_output``;
        - ``new_aux = log(exp(softmax_aux) + exp(partial_softmax_aux))``.

        A partial result whose ``partial_softmax_aux`` is ``-inf`` therefore leaves
        ``output`` unchanged.

        The ``transpose(0, 2, 1, 3)`` assumes softmax_aux is laid out as
        (batch, head, seq, 1) and output as (batch, seq, head, dim). This is how
        FusedRingAttnFwdPrimitive allocates them.

        See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
        derivation of this equation.
        """
        new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
            0, 2, 1, 3
        ) * (output - partial_output)
        new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
        return new_out, new_aux

    def adjust_seqlen(self, seqlen, max_seqlen, idx):
        """Adjust the sequence length per step.

        Returns ``seqlen - max_seqlen * idx`` clamped to the range ``[0, max_seqlen]``.
        """
        return jnp.clip(seqlen - max_seqlen * idx, 0, max_seqlen)


class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Ring Attention Forward Primitive

    KV blocks are passed around the context-parallel ring with ``ppermute``.

    - At every step, each rank computes attention of its local queries against the KV block
      it currently holds.
    - It then merges the partial result into a float32 running output and running softmax
      statistics.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partitions the fused attention forward pass for the P2P ring strategy.

        If the ``config.cp_axis`` mesh axis has size 1, this defers to
        ``FusedAttnFwdPrimitive.partition``. Otherwise it returns
        ``(mesh, ring_attn_fwd_impl, out_shardings, arg_shardings)``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        # The seed (argument 4) gets the same sharding as the returned RNG state.
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def ring_attn_fwd_impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            batch, q_max_seqlen, head, _ = q.shape
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            # Rank i sends to rank i + 1. After `idx` permutes, rank r therefore holds the KV
            # block that originated on rank (r - idx) mod cp_size.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            # Running output and softmax statistics are kept in float32.
            output = jnp.zeros(q.shape, dtype=jnp.float32)
            softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

            def scan_kv_block(idx, carry):
                """One ring step: start forwarding `kv`, attend to it, merge the partial result."""
                kv, output, softmax_aux = carry

                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)

                def run_step(q_in, kv_in, q_seqlen_in, kv_seqlen_in, attn_mask_type):
                    """Runs one non-CP fused attention forward on (q_in, kv_in).

                    Returns (output, softmax_aux); the RNG state is discarded.
                    """
                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                        q_in,
                        kv_in,
                        _not_used,
                        bias,
                        seed,
                        q_seqlen_in,
                        kv_seqlen_in,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(attn_mask_type),
                    )
                    return output_per_step, softmax_aux_per_step

                def mask_compute(attn_mask_type):
                    return run_step(
                        q,
                        kv,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        attn_mask_type,
                    )

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # All local queries attend (unmasked) to the first half of the held KV block.
                    kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1)
                    return run_step(
                        q,
                        kv_part,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2,
                        AttnMaskType.NO_MASK,
                    )

                def half_q_no_mask_compute():
                    # Only the second half of the local queries attends (unmasked) to the held
                    # KV block. The first half gets zero output and -inf statistics, which the
                    # correction step treats as "no contribution".
                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
                    output_per_step, softmax_aux_per_step = run_step(
                        q_part,
                        kv,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2,
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        AttnMaskType.NO_MASK,
                    )
                    output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
                    softmax_aux_per_step = jnp.concat(
                        [
                            jnp.full_like(softmax_aux_per_step, -jnp.inf),
                            softmax_aux_per_step,
                        ],
                        axis=2,
                    )
                    return output_per_step, softmax_aux_per_step

                def skip_compute():
                    # The held block is entirely in the future for every local query.
                    output_per_step = jnp.zeros_like(q)
                    softmax_aux_per_step = jnp.full(
                        (batch, head, q.shape[1], 1), -jnp.inf, dtype=jnp.float32
                    )
                    return output_per_step, softmax_aux_per_step

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # Step 0 is the diagonal block and needs a causal mask. For later steps
                    # (idx <= cp_rank means the block came from an earlier rank):
                    # - Load balanced: assuming rank r holds sequence chunks r and
                    #   2 * cp_size - 1 - r, earlier blocks are seen by all queries through
                    #   their first half; later blocks are seen only by the second half of
                    #   the queries.
                    # - Otherwise: earlier blocks are fully visible and later blocks are skipped.
                    def jax_cond_wrap():
                        # This is for nested jax.lax.cond
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    output_per_step, softmax_aux_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    output_per_step, softmax_aux_per_step = no_mask_compute()

                def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    # pylint: disable=unused-argument
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    return helper.correct_output_and_softmax_aux(
                        output, softmax_aux, output_per_step, softmax_aux_per_step
                    )

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    (idx == 0),
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, output, softmax_aux)

            carry = (kv, output, softmax_aux)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(cp_size):
                    carry = scan_kv_block(i, carry)
            # The final KV block in the carry is not needed.
            _, output, softmax_aux = carry

            output = output.astype(q.dtype)
            return output, softmax_aux, rng_state

        return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnFwdPrimitive)


class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Ring Attention Backward Primitive

    Each KV block travels around the context-parallel ring together with its running dK/dV
    accumulator. Every rank adds its local contribution to that accumulator and to its own dQ.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partitions the fused attention backward pass for the P2P ring strategy.

        If the ``config.cp_axis`` mesh axis has size 1, this defers to
        ``FusedAttnBwdPrimitive.partition``. Otherwise it returns
        ``(mesh, ring_attn_bwd_impl, out_shardings, arg_shardings)``. Each gradient is
        sharded like its corresponding input.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Validate before doing any sharding work (same order as the forward primitive).
        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        def ring_attn_bwd_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            _not_used = jnp.zeros(0, dtype=output.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            q_max_seqlen = q.shape[1]
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            # Rank i sends to rank i + 1. After `idx` permutes, rank r therefore holds the KV
            # block that originated on rank (r - idx) mod cp_size.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
            dbias = jnp.zeros_like(bias)

            def scan_kv_block(idx, carry):
                """One ring step: forward (kv, dk_dv), compute local grads, accumulate them."""
                kv, dq, dk_dv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.
                kv_dk_dv = jnp.stack([kv, dk_dv])
                kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)

                def run_step(
                    q_in,
                    kv_in,
                    softmax_aux_in,
                    output_in,
                    doutput_in,
                    q_seqlen_in,
                    kv_seqlen_in,
                    attn_mask_type,
                ):
                    """Runs one non-CP fused attention backward; returns (dq, dk_dv, dbias)."""
                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
                        q_in,
                        kv_in,
                        _not_used,
                        bias,
                        softmax_aux_in,
                        rng_state,
                        output_in,
                        doutput_in,
                        q_seqlen_in,
                        kv_seqlen_in,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(attn_mask_type),
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def mask_compute(attn_mask_type):
                    return run_step(
                        q,
                        kv,
                        softmax_aux,
                        output,
                        doutput,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        attn_mask_type,
                    )

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # All local queries attend to the first half of the held KV block. The
                    # second half of dK/dV receives no contribution and is zero-filled.
                    kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
                    dq_per_step, dk_dv_per_step, dbias_per_step = run_step(
                        q,
                        kv_part,
                        softmax_aux,
                        output,
                        doutput,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2,
                        AttnMaskType.NO_MASK,
                    )
                    dk_dv_per_step = jnp.concat(
                        [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def half_q_no_mask_compute():
                    # Only the second half of the local queries attends to the held KV block.
                    # dQ of the first half receives no contribution and is zero-filled.
                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
                    doutput_part = lax.slice_in_dim(
                        doutput, q_max_seqlen // 2, q_max_seqlen, axis=1
                    )
                    output_part = lax.slice_in_dim(output, q_max_seqlen // 2, q_max_seqlen, axis=1)
                    softmax_aux_part = lax.slice_in_dim(
                        softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
                    )
                    dq_per_step, dk_dv_per_step, dbias_per_step = run_step(
                        q_part,
                        kv,
                        softmax_aux_part,
                        output_part,
                        doutput_part,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2,
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        AttnMaskType.NO_MASK,
                    )
                    dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def skip_compute():
                    # The held block is entirely in the future for every local query.
                    return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # Same block schedule as FusedRingAttnFwdPrimitive: causal diagonal at
                    # step 0, then chosen by whether the block came from an earlier rank
                    # (idx <= cp_rank).
                    def jax_cond_wrap():
                        # This is for nested jax.lax.cond
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute()

                # The dk_dv received from the previous rank is its accumulator for the KV block
                # it used at step idx - 1. That is the same block `kv` used here (it is zeros at
                # idx 0), so this step's contribution is added to it.
                kv_next, dk_dv = jnp.unstack(kv_dk_dv)
                dq = dq + dq_per_step
                dk_dv = dk_dv + dk_dv_per_step
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step

                return (kv_next, dq, dk_dv, dbias)

            carry = (kv, dq, dk_dv, dbias)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(cp_size):
                    carry = scan_kv_block(i, carry)
            _, dq, dk_dv, dbias = carry

            # After the loop, rank r holds the finished dK/dV of the block owned by rank r + 1.
            # One more permute puts every gradient back onto its owning rank.
            dk_dv = helper.permute_kv(dk_dv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dk_dv)
            return dq, dk, dv, global_dbias

        return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnBwdPrimitive)


def adjust_cp_striped_window_size(q_pos0, kv_pos0, cp_size, window_size):
    """
    Adjust window size with cp_size for striped sharding, where both q_pos and
    kv_pos are arithmetic sequences like [x, x+cp_size, x+2*cp_size, ...].

    All arguments must be host (Python) values; the result is used to build a
    static config for the attention kernel.

    Args:
        q_pos0: Position of the first query token of the local query stripe.
        kv_pos0: Position of the first key/value token of the current KV stripe.
        cp_size: Context-parallel size, i.e. the stride between consecutive positions.
        window_size: ``(left, right)`` sliding window measured in token positions.
            A negative entry means that side is unlimited; it is returned as ``-1``
            (the same sentinel used by ``window_size == (-1, -1)`` for "no window").

    Returns:
        ``(left, right)`` window size measured in steps of the strided sequences.

    Example 1:
        q_pos = kv_pos = [0, 8, 16, 24, 32], cp_size = 8, window_size = (15, 0).
        q_pos = 32 can look at kv_pos at [24, 32]. The effective mask is:
              0  8 16 24 32
           ----------------
         0 |  1  0  0  0  0
         8 |  1  1  0  0  0
        16 |  0  1  1  0  0
        24 |  0  0  1  1  0
        32 |  0  0  0  1  1
        SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
        Adjusted window size = (1, 0).
    Example 2:
        q_pos = [0, 8, 16, 24, 32], kv_pos = [1, 9, 17, 25, 33], cp_size = 8,
        window_size = (15, 0). The effective mask is:
              1  9 17 25 33
           ----------------
         0 |  0  0  0  0  0
         8 |  1  0  0  0  0
        16 |  1  1  0  0  0
        24 |  0  1  1  0  0
        32 |  0  0  1  1  0
        SequenceDescriptor outputs:
        q_seqlen = [4, ...], q_seq_offsets = [1, ...],
        kv_seqlen = [4, ...], kv_seq_offsets = [0, ...].
        If diagonal are all 1, left window size = 2. Now since diagonal are all 0,
        we need to use left window size = 2 - 1 = 1 to make cuDNN work.
    Example 3:
        q_pos = [7, 15, 23, 31, 39], kv_pos = [0, 8, 16, 24, 32], cp_size = 8,
        window_size = (22, 0). The effective mask is:
              0  8 16 24 32
           ----------------
         7 |  1  0  0  0  0
        15 |  1  1  0  0  0
        23 |  0  1  1  0  0
        31 |  0  0  1  1  0
        39 |  0  0  0  1  1
        SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
        Adjust window size = (1, 0).

    NOTE(review): when kv_pos0 > q_pos0 and no kv position lies inside a finite left
    window (window smaller than the stripe offset), the adjusted left size becomes -1,
    which coincides with the "unlimited" sentinel. Callers should confirm that this
    case cannot occur (or skip the step) — TODO confirm.
    """
    left_window, right_window = window_size

    if left_window < 0:
        # Unlimited left window remains unlimited after striping.
        left_steps = -1
    else:
        left_limit = q_pos0 - left_window
        # Count how many left steps of size cp_size we can take from kv_pos0 - cp_size
        left_steps = max((kv_pos0 - cp_size - left_limit) // cp_size + 1, 0)
        # If kv_pos0 > q_pos0, we must reduce left window size by 1 (see Example 2)
        if kv_pos0 > q_pos0:
            left_steps -= 1

    if right_window < 0:
        # Unlimited right window remains unlimited after striping.
        right_steps = -1
    else:
        right_limit = q_pos0 + right_window
        # Count how many right steps of size cp_size we can take from kv_pos0 + cp_size
        right_steps = max((right_limit - kv_pos0 - cp_size) // cp_size + 1, 0)

    return left_steps, right_steps


class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Striped Ring Attention Forward Primitive

    Each rank keeps its query shard and rotates the stacked KV block (together with the
    KV segment ids/positions) around the context-parallel axis. The partial output of each
    step is merged into a float32 running output using the per-step softmax statistics
    (log-sum-exp style correction). Dropout is not supported by this path.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        # Without a context-parallel axis there is nothing to ring over.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Argument 4 is the seed; it shares the rng_state sharding.
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def fwd_impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            if not config.qkv_layout.is_qkvpacked():
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # We need cp_rank to be a host value for adjust_cp_striped_window_size()
            cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
            # Ring permutation: rank i sends its KV block to rank i + 1.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            batch, q_max_seqlen, head, _ = q.shape
            output = jnp.zeros(q.shape).astype(jnp.float32)
            softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

            def scan_kv_block(idx, carry):
                kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry

                # TODO(rewang): To check whether we need special handle for the last idx
                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                def compute(step_config):
                    return FusedAttnFwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        seed,
                        q_seqlen,
                        kv_seqlen,
                        q_seq_offsets,
                        k_seq_offsets,
                        q_segment_ids,
                        kv_segment_ids,
                        q_segment_pos,
                        kv_segment_pos,
                        step_config,
                    )

                if config.window_size != (-1, -1):
                    # After idx ring steps the local KV block originated on this rank.
                    kv_src_rank = (cp_size + cp_rank - idx) % cp_size
                    # Note: all inputs of adjust_cp_striped_window_size should be host values
                    cp_striped_window_size = adjust_cp_striped_window_size(
                        cp_rank, kv_src_rank, cp_size, config.window_size
                    )
                    current_config = replace(
                        subblock_config, cp_striped_window_size=cp_striped_window_size
                    )
                else:
                    current_config = subblock_config
                output_per_step, softmax_aux_per_step, _ = compute(current_config)

                softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))

                def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # Merge two partial softmax results; new_aux = log(exp(aux) + exp(aux_step)).
                    new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * (
                        output - output_per_step
                    )
                    new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step)
                    return new_out, new_aux

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    idx == 0,
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux)

            carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux)
            # With a sliding window, each step derives a static config from `idx` on the host.
            # Inside lax.fori_loop `idx` is a tracer, so the loop must be unrolled instead.
            use_scan = helper.use_scanloop() and config.window_size == (-1, -1)
            if use_scan:
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (_, _, _, output, softmax_aux) = carry

            return output.astype(q.dtype), softmax_aux, rng_state

        return mesh, fwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnStripedFwdPrimitive)


class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Striped Ring Attention Backward Primitive

    The KV block and its accumulated gradient are stacked and rotated together around the
    context-parallel axis, so each KV gradient travels with its KV block. After the loop a
    final permute returns every dK/dV to the rank that owns the corresponding KV shard.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        # Without a context-parallel axis there is nothing to ring over.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        arg_shardings = tuple(arg.sharding for arg in arg_infos)
        # dq, dk, dv, dbias sharding = q, k, v, bias sharding
        out_shardings = tuple(arg.sharding for arg in arg_infos[:4])

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def bwd_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):

            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            _not_used = jnp.zeros(0, dtype=output.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            if not config.qkv_layout.is_qkvpacked():
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # We need cp_rank to be a host value for adjust_cp_striped_window_size()
            cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
            # Ring permutation: rank i sends its KV block to rank i + 1.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dkv = jnp.zeros_like(kv)
            dbias = jnp.zeros_like(bias)

            def scan_kv_block(idx, carry):
                kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.
                kv_dkv = jnp.stack([kv, dkv])
                kv_dkv = helper.permute_kv(kv_dkv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                def compute(step_config):
                    dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        softmax_aux,
                        rng_state,
                        output,
                        doutput,
                        q_seqlen,
                        kv_seqlen,
                        q_seq_offsets,
                        k_seq_offsets,
                        q_segment_ids,
                        kv_segment_ids,
                        q_segment_pos,
                        kv_segment_pos,
                        config=step_config,
                    )
                    return dq_per_step, dkv_per_step, dbias_per_step

                if config.window_size != (-1, -1):
                    # After idx ring steps the local KV block originated on this rank.
                    kv_src_rank = (cp_size + cp_rank - idx) % cp_size
                    # Note: all inputs of adjust_cp_striped_window_size should be host values
                    cp_striped_window_size = adjust_cp_striped_window_size(
                        cp_rank, kv_src_rank, cp_size, config.window_size
                    )
                    current_config = replace(
                        subblock_config, cp_striped_window_size=cp_striped_window_size
                    )
                else:
                    current_config = subblock_config
                dq_per_step, dkv_per_step, dbias_per_step = compute(current_config)

                # dkv arriving here is the gradient accumulated so far for kv_next's block.
                kv_next, dkv = jnp.unstack(kv_dkv)
                dq += dq_per_step
                dkv += dkv_per_step
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step

                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, dq, dkv, dbias)

            carry = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias)
            # With a sliding window, each step derives a static config from `idx` on the host.
            # Inside lax.fori_loop `idx` is a tracer, so the loop must be unrolled instead.
            use_scan = helper.use_scanloop() and config.window_size == (-1, -1)
            if use_scan:
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for idx in range(cp_size):
                    carry = scan_kv_block(idx, carry)
            (_, _, _, dq, dkv, dbias) = carry

            # Final permute to put gradients back to their final resting place.
            dkv = helper.permute_kv(dkv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dkv)
            return dq, dk, dv, global_dbias

        return mesh, bwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnStripedBwdPrimitive)


def _maybe_context_parallel_axis(cp_axis: str):
    if not cp_axis:
        gmr = global_mesh_resource()
        if gmr is not None:
            cp_axis = gmr.cp_resource
        else:
            cp_axis = ""
    return cp_axis


2541
2542
2543
def fused_attn_fwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
2544
    sequence_descriptor: SequenceDescriptor,
2545
    seed: Optional[jnp.ndarray],
Reese Wang's avatar
Reese Wang committed
2546
2547
2548
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
2549
2550
2551
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
2552
    max_segments_per_seq: int,
2553
    window_size: Optional[Tuple[int, int]] = None,
2554
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
2555
2556
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
2557
) -> jnp.ndarray:
2558
    """
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
    Perform the forward pass of with cuDNN fused attention implementations.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
Reese Wang's avatar
Reese Wang committed
2579
2580
2581
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
2582
2583
2584
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
2585
2586
2587
2588
2589
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
2590
2591
2592
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
2593
2594
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
2595
    """
2596
2597
2598
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
2599

Reese Wang's avatar
Reese Wang committed
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
2617
        assert bias is None
2618
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
2619

2620
    fused_config = _FusedAttnConfig(
2621
2622
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
2623
        qkv_layout=qkv_layout,
2624
2625
2626
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
2627
        max_segments_per_seq=max_segments_per_seq,
2628
        window_size=(-1, -1) if window_size is None else window_size,
2629
2630
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
2631
        cp_striped_window_size=None,
2632
2633
    )

2634
    primitive = None
2635
2636
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
2637
            primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
2638
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
2639
2640
2641
2642
2643
            # We must use stripe attention for THD-RING
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnFwdPrimitive.outer_primitive
2644

2645
    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
2646
    output, softmax_aux, rng_state = primitive.bind(
2647
2648
2649
        *qkv_for_primitive,
        bias,
        seed,
2650
        *seq_desc_flatten,
2651
        config=fused_config,
2652
    )
2653
2654
    rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
    return (output, softmax_aux, rng_state)
2655
2656


2657
2658
2659
def fused_attn_bwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
2660
2661
2662
2663
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
2664
    sequence_descriptor: SequenceDescriptor,
Reese Wang's avatar
Reese Wang committed
2665
2666
2667
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
2668
2669
2670
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
2671
    max_segments_per_seq: int,
2672
    window_size: Optional[Tuple[int, int]] = None,
2673
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
2674
2675
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
2676
):
2677
    """
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
    Perform the backward pass of the cuDNN fused attention implementations.

    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
        used in the forward pass. It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
        rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
        output (jnp.ndarray): The output tensor from the forward pass.
        doutput (jnp.ndarray): The gradient with respect to the output.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
Reese Wang's avatar
Reese Wang committed
2699
2700
2701
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
2702
2703
2704
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
2705
2706
2707
2708
2709
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size .
2710
2711
2712
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
2713
2714
2715
2716
2717
    Returns:
        Tuple[jnp.ndarray, ...], jnp.ndarray:
        - The first tuple contains the gradients with respect to the input `qkv` tensors in the
          same format as the input `qkv`.
        - The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
2718
    """
2719
2720
2721
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)

Reese Wang's avatar
Reese Wang committed
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
2739
        assert bias is None
2740
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
2741

2742
2743
2744
2745
2746
    if 100 in get_all_device_compute_capability():
        assert not (
            attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
        ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"

2747
2748
2749
2750
2751
2752
2753
2754
    fused_config = _FusedAttnConfig(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
2755
        window_size=(-1, -1) if window_size is None else window_size,
2756
2757
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
2758
        cp_striped_window_size=None,
2759
2760
    )

2761
    primitive = None
2762
2763
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
2764
            primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
2765
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
2766
2767
2768
2769
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnBwdPrimitive.outer_primitive
2770
2771
2772

    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
    *qkv_grads, bias_grad = primitive.bind(
2773
        *qkv_for_primitive,
2774
2775
2776
2777
2778
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
2779
        *seq_desc_flatten,
2780
        config=fused_config,
2781
    )
2782
    return tuple(qkv_grads[: len(qkv)]), bias_grad