attention.py 101 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
import operator
6
import os
7
import warnings
8
9
10
11
from dataclasses import dataclass, replace
from functools import partial, reduce
from typing import Optional, Tuple
from packaging import version
12

13
import jax
14
import jax.numpy as jnp
15
from jax import dtypes, lax
16
17
18
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
19
20
21

import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
Reese Wang's avatar
Reese Wang committed
22
23
24
25
26
27
28
29
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    CPStrategy,
    SequenceDescriptor,
)
30

31
32
33
34
35
36
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
37
    get_padded_spec,
38
    get_cudnn_version,
39
    is_ffi_enabled,
40
41
)
from ..sharding import (
42
43
    global_mesh_resource,
    lax_paral_op,
44
    all_reduce_sum_along_dp_fsdp,
45
46
    get_mesh_axis_size,
    get_mesh_axis_rank,
47
48
49
50
51
    get_all_mesh_axes,
    num_of_devices,
)


52
53
54
55
56
57
if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


58
59
60
61
62
# Names exported by ``from <this module> import *``.
__all__ = ["FusedAttnHelper", "fused_attn_fwd", "fused_attn_bwd"]
63
64


65
66
67
68
69
70
71
72
73
74
75
@partial(
    jax.tree_util.register_dataclass,
    # No field is a pytree leaf: every attribute is registered as metadata, so the
    # whole config is treated as static by JAX tree utilities.
    data_fields=(),
    meta_fields=(
        "attn_bias_type",
        "attn_mask_type",
        "qkv_layout",
        "scaling_factor",
        "dropout_probability",
        "is_training",
        "max_segments_per_seq",
        "window_size",
        "context_parallel_load_balanced",
        "cp_axis",
    ),
)
@dataclass(frozen=True)
class _FusedAttnConfig:
    """
    Static configuration shared by the fused attention primitives.

    The dataclass is frozen and all of its fields are registered as pytree
    metadata (no data fields), so an instance carries only static values.
    """

    # Kind of bias applied to the attention scores.
    attn_bias_type: AttnBiasType
    # Kind of attention mask.
    attn_mask_type: AttnMaskType
    # Memory layout / packing of q, k and v.
    qkv_layout: QKVLayout
    # Scale passed to the backend as ``scaling_factor``.
    scaling_factor: float
    # Dropout probability passed to the backend.
    dropout_probability: float
    # Forwarded to the backend as ``is_training``.
    is_training: bool
    # Upper bound on packed segments per sequence; also sizes the softmax aux
    # buffer on older cuDNN versions.
    max_segments_per_seq: int
    # (left, right) window, forwarded as window_size_left / window_size_right.
    window_size: Tuple[int, int]
    # NOTE(review): presumably toggles load-balanced context parallelism — confirm.
    context_parallel_load_balanced: bool
    # NOTE(review): presumably the mesh axis name used for context parallelism — confirm.
    cp_axis: str


99
100
101
102
103
104
105
106
@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend.

    Bundles the problem description (dtypes, layout, bias/mask kinds, head counts,
    max sequence lengths, head dim and window) needed to ask
    ``transformer_engine_jax`` which fused-attention backend would handle it.
    """

    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: QKVLayout
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
    q_max_seqlen: int
    kv_max_seqlen: int
    head_dim: int
    window_size: Tuple[int, int]

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """
        Get the fused attention kernel backend.

        Returns the ``NVTE_Fused_Attn_Backend`` chosen by ``transformer_engine_jax``
        for this configuration (``NVTE_No_Backend`` when none applies).
        """
        return transformer_engine_jax.get_fused_attn_backend(
            jax_dtype_to_te_dtype(self.q_dtype),
            jax_dtype_to_te_dtype(self.kv_dtype),
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_probability,
            self.q_num_heads,
            self.kv_num_heads,
            self.q_max_seqlen,
            self.kv_max_seqlen,
            self.head_dim,
            self.window_size[0],
            self.window_size[1],
        )

    @staticmethod
    def is_non_deterministic_allowed():
        """
        Check if non-deterministic kernels are allowed.

        Controlled by the ``NVTE_ALLOW_NONDETERMINISTIC_ALGO`` environment variable
        (integer flag, default ``1`` = allowed).
        """
        return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

    @staticmethod
    def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
        """
        Parse the shapes of q/k/v according to ``qkv_layout``.

        Expected trailing dimensions per layout:
          * qkv-packed: q = (...batch, seqlen, 3, heads, head_dim)
          * kv-packed:  q = (...batch, q_seqlen, heads, head_dim),
                        k = (...batch, kv_seqlen, 2, gqa_groups, head_dim)
          * separate:   q = (...batch, q_seqlen, heads, head_dim),
                        k = v = (...batch, kv_seqlen, gqa_groups, head_dim)

        Returns:
            (batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim)

        Raises:
            NotImplementedError: for SBHD-format layouts.
            ValueError: for a layout that is neither packed nor separate.
        """
        if qkv_layout.get_qkv_format() == QKVFormat.SBHD:
            raise NotImplementedError(f"SBHD layouts are not supported, got {qkv_layout=}")
        if qkv_layout.is_qkvpacked():
            *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
            # Packed q/k/v share batch, sequence length, heads and head_dim.
            kv_batch_shape = q_batch_shape
            kv_max_seqlen = q_max_seqlen
            num_gqa_groups = attn_heads
            kv_head_dim = q_head_dim
            assert nqkv == 3, f"Expected 3 packed tensors (q, k, v), got {nqkv=} for {qkv_layout=}"
        elif qkv_layout.is_kvpacked():
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
            assert nkv == 2, f"Expected 2 packed tensors (k, v), got {nkv=} for {qkv_layout=}"
        elif qkv_layout.is_separate():
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
            assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}"
        else:
            raise ValueError(f"Unexpected {qkv_layout=}")

        assert q_batch_shape == kv_batch_shape, f"{q_batch_shape=} != {kv_batch_shape=}"
        assert q_head_dim == kv_head_dim, f"{q_head_dim=} != {kv_head_dim=}"
        assert (
            q_aval.dtype == k_aval.dtype == v_aval.dtype
        ), f"{q_aval.dtype=} {k_aval.dtype=} {v_aval.dtype=}"

        return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Guard for the RNG state handed to the fused attention backend.

    The backend consumes a 64-bit seed and a 64-bit offset, but JAX keeps 64-bit
    types disabled by default, so each 64-bit quantity is represented here by a
    pair of 32-bit words. The offset itself is computed inside the backend.
    """

    # dtype of the 32-bit words used to emulate the 64-bit values
    rng_state_dtype: jnp.dtype = jnp.uint32
    # the seed is one int64, i.e. two 32-bit words
    seed_size: int = 2
    # the full state is (seed, offset): two int64 values, i.e. four 32-bit words
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Validate ``seed`` and cast it to ``rng_state_dtype`` when required.

        Args:
            seed: seed array, or ``None`` when dropout is inactive.
            dropout_probability: dropout rate of the attention call.
            is_training: whether the call runs in training mode.

        Returns:
            The seed as a ``rng_state_dtype`` array with at least ``seed_size``
            elements. A ``None`` seed is replaced by zeros.
        """
        if seed is None:
            # JAX cannot bind ``None`` as an operand; substitute a zero seed, which
            # is only acceptable when dropout is not actually applied.
            assert not (
                dropout_probability > 0 and is_training
            ), "seed is not allowed to be None when dropout is enabled."
            seed = jnp.repeat(jnp.zeros(2, dtype=self.rng_state_dtype), num_of_devices())

        dtype_mismatch = seed.dtype != self.rng_state_dtype
        if dtype_mismatch:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                "Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning."
            )
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # The backend reads a single int64_t seed, i.e. only the first two u32 words.
        assert seed.size >= self.seed_size

        return seed


def generate_cu_seqlen(actual_seqlen):
    """
    Build the cumulative-sum sequence lengths for a batch.

    Negative entries (used as padding markers) are clamped to zero before a
    running sum with a leading zero is taken. For a 1-D input of length ``n``
    the result therefore has ``n + 1`` elements.
    """
    non_negative = jnp.maximum(actual_seqlen, 0)
    return jnp.cumulative_sum(non_negative, include_initial=True)


class FusedAttnFwdPrimitive(BasePrimitive):
    """
    Fused Attention Forward Primitive

    The outer primitive yields (output, softmax_aux, rng_state). The inner
    primitive additionally yields a workspace buffer whose shape/dtype is
    queried from the backend; ``impl`` discards it.
    """

    name = "te_fused_attn_forward"
    multiple_results = True
    # ``config`` is positional argument #13 of ``impl`` and is static.
    impl_static_args = (13,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        seed_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd abstract

        Returns the avals of (output, softmax_aux, rng_state, workspace) without
        running a kernel.
        """
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert (
            q_dtype == k_dtype == v_dtype == bias_dtype
        ), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}"
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
            f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
            f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
        )

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
        )

        output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
        out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)

        # backend determines the softmax buffer shape/dtype
        backend = FusedAttnHelper(
            q_dtype,
            k_dtype,
            config.qkv_layout,
            config.attn_bias_type,
            config.attn_mask_type,
            config.dropout_probability,
            attn_heads,
            num_gqa_groups,
            q_max_seqlen,
            kv_max_seqlen,
            head_dim,
            config.window_size,
        ).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
            softmax_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            # cuDNN 9.6 reduces the required softmax shape
            if get_cudnn_version() >= (9, 6, 0):
                softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
            else:
                softmax_shape = (
                    *batch_shape,
                    attn_heads,
                    q_max_seqlen,
                    config.max_segments_per_seq,
                )
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f"Unsupported {backend=}")
        softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)

        # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
        # 32-bit unsigned int to get the buffer size we need in the C++ kernel
        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
        # prepare for the active fused-attn backend
        input_batch = reduce(operator.mul, batch_shape)
        wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
        )
        wkspace_aval = q_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract

        Same as ``abstract`` without the trailing workspace aval.
        """
        out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract(
            *args, **kwargs
        )
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        seed,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd lowering rules

        Emits an XLA FFI call when ``is_ffi_enabled()``; otherwise emits a legacy
        custom call whose attributes are packed into an opaque descriptor.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
        )

        input_batch = reduce(operator.mul, batch_shape)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        if is_ffi_enabled():
            name = "te_fused_attn_forward_ffi"
            out = ffi.ffi_lowering(name)(
                ctx,
                q,
                k,
                v,
                bias,
                seed,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
                # Segment ids/pos are forwarded only so that the operand count
                # matches the primitive's inputs, as ffi_lowering requires.
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
                input_batch=input_batch,
                bias_batch=bias_batch,
                q_max_seqlen=q_max_seqlen,
                kv_max_seqlen=kv_max_seqlen,
                attn_heads=attn_heads,
                num_gqa_groups=num_gqa_groups,
                bias_heads=bias_heads,
                head_dim=head_dim,
                max_segments_per_seq=config.max_segments_per_seq,
                scaling_factor=float(config.scaling_factor),
                dropout_probability=float(config.dropout_probability),
                bias_type=int(config.attn_bias_type.value),
                mask_type=int(config.attn_mask_type.value),
                qkv_layout=int(config.qkv_layout.value),
                is_training=config.is_training,
                deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
                window_size_left=config.window_size[0],
                window_size_right=config.window_size[1],
            )
        else:
            operands = [
                q,
                k,
                v,
                bias,
                seed,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
            ]
            operand_shapes = map(lambda x: x.type.shape, operands)
            out_types = [
                ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
                for output in ctx.avals_out
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            wkspace_aval = ctx.avals_out[-1]

            # The config holds Python Enums; the backend expects the underlying
            # NVTE enum values, exactly as in the FFI path and the workspace query.
            opaque = transformer_engine_jax.pack_fused_attn_descriptor(
                input_batch,
                bias_batch,
                q_max_seqlen,
                kv_max_seqlen,
                attn_heads,
                num_gqa_groups,
                bias_heads,
                head_dim,
                config.max_segments_per_seq,
                wkspace_aval.size,
                config.scaling_factor,
                config.dropout_probability,
                config.attn_bias_type.value,
                config.attn_mask_type.value,
                config.qkv_layout.value,
                jax_dtype_to_te_dtype(q_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                config.is_training,
                not FusedAttnHelper.is_non_deterministic_allowed(),
                config.window_size[0],
                config.window_size[1],
            )

            out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        seed,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd implementation

        Resolves seqlens/offsets through ``SequenceDescriptor``, compacts them for
        THD layouts, converts seqlens to cumulative form and binds the inner
        primitive. Returns (output, softmax_aux, rng_state).
        """
        assert FusedAttnFwdPrimitive.inner_primitive is not None

        sequence_descriptor = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )

        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
                config.qkv_layout,
                config.window_size,
                config.max_segments_per_seq,
            )
        )

        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
                # Move the elements of ``x`` selected by ``condition`` to the front
                # (row-major order), pad the rest with ``fill_value`` and restore
                # the original shape.
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
                # Shift each row's non-negative offsets by row_index * max_seqlen so
                # they index the flattened (batch * max_seqlen) token axis.
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
            kv_batch = q_batch = batch[0]

            # Gather valid q_seqlen, which is greater than 0
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1
            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid q_seq_offsets, which are greater than or equal to 0
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # And set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _fix_len_take(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _fix_len_take(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """
        Batching rule: output and softmax_aux follow q's batch dim; rng_state
        follows the seed's batch dim.
        """
        check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, _, _, _, seed_bdim, *_ = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
        return (
            FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """
        Derive the output shardings from the sharding of ``q``.
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        if config.qkv_layout.is_qkvpacked():
            # q_spec = (...batch, q_seqlen, 3, head, hidden)
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
            softmax_aux_sharding = NamedSharding(
                mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
            )
        elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
            # q_spec = (...batch, q_seqlen, head, hidden)
            # kv-packed: k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
            # separate:  k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
            softmax_aux_sharding = NamedSharding(
                mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
            )
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """
        Partitioning rule.

        Seed and rng_state are sharded along all mesh axes. The segment positions
        (last two args) reuse the shardings of the matching segment ids.
        """
        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[4] = seed_sharding
        # kv_segment_pos <- kv_segment_ids, q_segment_pos <- q_segment_ids
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(FusedAttnFwdPrimitive.impl, config=config)
        return mesh, impl, out_shardings, arg_shardings


# Register the forward primitive via the project's ``register_primitive`` helper.
# NOTE(review): presumably this is what populates FusedAttnFwdPrimitive.inner_primitive /
# outer_primitive (impl/batcher assert they are set) — confirm against .base.
register_primitive(FusedAttnFwdPrimitive)


class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive
    """
659

660
661
    name = "te_fused_attn_backward"
    multiple_results = True
662
    impl_static_args = (16,)
663
664
665
666
    inner_primitive = None
    outer_primitive = None

    @staticmethod
667
668
669
670
671
672
673
674
675
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_aux_aval,
        rng_state_aval,
        output_aval,
        doutput_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd abstract

        Returns the avals of (dq, dk, dv, dbias, workspace). The gradients keep the
        shapes and dtypes of q, k, v and bias; the workspace shape/dtype is queried
        from ``transformer_engine_jax``.
        """
        # These inputs do not influence any output aval.
        del softmax_aux_aval, rng_state_aval, output_aval

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype, (
            f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, "
            f"bias_dtype={bias_dtype}, doutput_dtype={doutput_dtype}"
        )
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
            f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
            f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
        )

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
        )

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        deterministic = not FusedAttnHelper.is_non_deterministic_allowed()

        input_batch = reduce(operator.mul, batch_shape)
        wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            deterministic,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
        )

        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(
            shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
        )

        return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract

        Same as ``FusedAttnBwdPrimitive.abstract`` but drops the trailing
        workspace aval, returning (dq, dk, dv, dbias).
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        return dq_aval, dk_aval, dv_aval, dbias_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        q_segment_ids,
        kv_segment_ids,
        q_segment_pos,
        kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd lowering rules

        When ``is_ffi_enabled()`` is true the primitive is lowered through the XLA FFI
        target ``te_fused_attn_backward_ffi`` with the problem sizes passed as attributes.
        Otherwise it is lowered to the legacy custom call, whose attributes are packed
        into an opaque descriptor by ``pack_fused_attn_descriptor``.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        qkv_dims = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = qkv_dims
        input_batch = reduce(operator.mul, batch_shape)

        if config.attn_bias_type != AttnBiasType.NO_BIAS:
            # Leading bias dims are folded into one batch size; the dim before the two
            # trailing ones is the bias head count.
            *bias_batch_dims, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_dims)
        else:
            # No bias is used, so both bias sizes are reported as zero.
            bias_batch = bias_heads = 0

        deterministic = not FusedAttnHelper.is_non_deterministic_allowed()

        if is_ffi_enabled():
            # ffi_lowering needs one operand per parameter of this lowering rule, so the
            # segment ids/pos are forwarded as well.
            return ffi.ffi_lowering("te_fused_attn_backward_ffi")(
                ctx,
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
                q_segment_ids,
                kv_segment_ids,
                q_segment_pos,
                kv_segment_pos,
                input_batch=input_batch,
                bias_batch=bias_batch,
                q_max_seqlen=q_max_seqlen,
                kv_max_seqlen=kv_max_seqlen,
                attn_heads=attn_heads,
                num_gqa_groups=num_gqa_groups,
                bias_heads=bias_heads,
                head_dim=head_dim,
                max_segments_per_seq=config.max_segments_per_seq,
                scaling_factor=float(config.scaling_factor),
                dropout_probability=float(config.dropout_probability),
                bias_type=int(config.attn_bias_type.value),
                mask_type=int(config.attn_mask_type.value),
                qkv_layout=int(config.qkv_layout.value),
                is_training=config.is_training,
                deterministic=deterministic,
                window_size_left=config.window_size[0],
                window_size_right=config.window_size[1],
            )

        # Legacy custom-call path: only the first twelve operands are handed to the kernel.
        legacy_operands = [
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
        ]
        result_types = [
            ir.RankedTensorType.get(aval.shape, mlir.dtype_to_ir_type(aval.dtype))
            for aval in ctx.avals_out
        ]
        args = CustomCallArgsWrapper(
            result_types,
            legacy_operands,
            map(operator.attrgetter("type.shape"), legacy_operands),
        )

        # The final output aval is the workspace; its size and dtype go into the descriptor.
        wkspace_aval = ctx.avals_out[-1]
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            config.max_segments_per_seq,
            wkspace_aval.size,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type,
            config.attn_mask_type,
            config.qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            config.is_training,
            deterministic,
            config.window_size[0],
            config.window_size[1],
        )
        return custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config,
    ):
        """Reference implementation of the fused-attention backward pass.

        Resolves sequence lengths/offsets through ``SequenceDescriptor``, compacts them
        for THD layouts, builds cumulative sequence lengths, and binds the inner
        primitive. Returns ``(dq, dk, dv, dbias)``; the trailing workspace output of
        the inner primitive is discarded.
        """
        assert FusedAttnBwdPrimitive.inner_primitive is not None

        seq_desc = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )
        seqlens, offsets = seq_desc.get_seqlens_and_offsets(
            config.attn_mask_type,
            config.qkv_layout,
            config.window_size,
            config.max_segments_per_seq,
        )
        q_seqlen, kv_seqlen = seqlens
        q_seq_offsets, k_seq_offsets = offsets

        if config.qkv_layout.is_thd():

            def _compact(values, keep, fill_value=-1):
                """Move entries where `keep` holds to the front (flattened order), fill the rest."""
                original_shape = values.shape
                flat = values.flatten()
                n = flat.size
                positions = jnp.nonzero(keep.flatten(), size=n, fill_value=n)[0]
                # TODO(rewang): try indices_are_sorted
                picked = jnp.take(flat, positions, fill_value=fill_value)
                return jnp.reshape(picked, original_shape)

            def _to_global_offsets(local_offsets, batch, max_seqlen):
                """Shift each row's valid (>= 0) offsets by row_index * max_seqlen."""
                row_base = (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis]
                return jnp.where(local_offsets >= 0, local_offsets + row_base, local_offsets)

            batch_shape, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch_shape) == 1
            q_batch = kv_batch = batch_shape[0]

            # Gather the valid (> 0) sequence lengths to the front.
            # cuDNN < 9.3.0 pads with -1:
            #   [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN >= 9.3.0 accepts act_seqlen = 0 and pads with 0:
            #   [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            seqlen_fill = 0 if get_cudnn_version() >= (9, 3, 0) else -1
            q_seqlen = _compact(q_seqlen, q_seqlen > 0, fill_value=seqlen_fill)
            kv_seqlen = _compact(kv_seqlen, kv_seqlen > 0, fill_value=seqlen_fill)

            # Make offsets global across the flattened batch:
            #   max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = _to_global_offsets(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = _to_global_offsets(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather the valid (>= 0) offsets to the front and fill unused slots with
            # batch * max_seqlen:
            #   [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _compact(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _compact(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return dq, dk, dv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """vmap rule: bind the outer primitive on the batched operands.

        dq/dk/dv take the batch dims of q/k/v. dbias is reported with q's batch dim.
        """
        check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, k_bdim, v_bdim, *_ = batch_dims

        grads = FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config)
        return grads, (q_bdim, k_bdim, v_bdim, q_bdim)
1009
1010

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """Shard dq, dk, dv and dbias like q, k, v and bias (operands 0-3), using padded specs."""
        del config, result_infos
        return tuple(
            NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[idx])))
            for idx in range(4)
        )

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """SPMD partition rule for the fused-attention backward primitive.

        Gradients follow the padded partition specs of q, k, v and bias. The two
        segment_pos operands are given the shardings of their matching segment_ids
        operands. When a bias is present, the per-shard dbias is summed across the
        dp/fsdp mesh axes.
        """
        del result_infos
        out_shardings = tuple(
            NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[idx])))
            for idx in range(4)
        )

        arg_shardings = [info.sharding for info in arg_infos]
        # Trailing operands are [..., q_segment_ids, kv_segment_ids, q_segment_pos, kv_segment_pos]:
        # copy the segment_ids shardings onto the segment_pos slots.
        arg_shardings[-2:] = arg_shardings[-4:-2]
        arg_shardings = tuple(arg_shardings)

        def sharded_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            dq, dk, dv, dbias = FusedAttnBwdPrimitive.impl(
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
                config=config,
            )
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                # Sum the local bias gradient over the dp/fsdp mesh axes.
                dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
            return dq, dk, dv, dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register the backward primitive via the shared `register_primitive` helper from `.base`.
register_primitive(FusedAttnBwdPrimitive)


Reese Wang's avatar
Reese Wang committed
1088
def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
    """Reorders a tensor for load balancing the compute of causal attention.

    The sequence axis is split into ``2 * cp_size`` equal chunks. With
    ``to_contiguous=False``, context-parallel rank ``r`` is assigned chunks
    ``r`` and ``2 * cp_size - 1 - r``. The output is the concatenation of these
    chunk pairs in rank order, so each rank holds one early and one late chunk.
    With ``to_contiguous=True`` the inverse permutation is applied, which restores
    the original contiguous order.

    Args:
        tensor: array whose axis ``seq_dim`` is the sequence axis.
        cp_size: context-parallel size. 1 is a no-op; otherwise it must be even.
        seq_dim: index of the sequence axis. Negative values count from the end.
        to_contiguous: if True, undo the load-balancing reorder instead of applying it.

    Returns:
        An array with the same shape as ``tensor``.

    Raises:
        ValueError: if ``cp_size`` is odd (and not 1), or the sequence length is not
            a multiple of ``2 * cp_size``.
    """
    if cp_size == 1:
        return tensor

    if cp_size % 2 != 0:
        raise ValueError(f"{cp_size=} must be a multiple of 2.")

    # The reshape below slices the shape around seq_dim, which needs a non-negative index.
    if seq_dim < 0:
        seq_dim += tensor.ndim

    num_chunks = 2 * cp_size
    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
    if tensor.shape[seq_dim] % num_chunks != 0:
        raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")

    # Load-balanced chunk order: rank r holds chunks (r, 2*cp_size - 1 - r).
    balanced = [chunk for r in range(cp_size) for chunk in (r, num_chunks - 1 - r)]
    if to_contiguous:
        # Inverse permutation: order[chunk] is the position of `chunk` in `balanced`.
        order = [0] * num_chunks
        for pos, chunk in enumerate(balanced):
            order[chunk] = pos
    else:
        order = balanced

    # [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]
    # [S, B, H, D] -> [2*cp_size, S/(2*cp_size), B, H, D]
    ori_tensor_shape = tensor.shape
    chunked = tensor.reshape(
        (
            *ori_tensor_shape[:seq_dim],
            num_chunks,
            ori_tensor_shape[seq_dim] // num_chunks,
            *ori_tensor_shape[seq_dim + 1 :],
        )
    )
    reordered = jnp.take(chunked, jnp.array(order), axis=seq_dim)
    return reordered.reshape(ori_tensor_shape)


Reese Wang's avatar
Reese Wang committed
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
    """Reorders a tensor for load balancing with striped pattern.

    With ``is_inverse=False``, the token at sequence index ``i * cp_size + j`` is
    moved to position ``j * (S // cp_size) + i``. After this, the j-th contiguous
    block of length ``S // cp_size`` holds tokens ``j, j + cp_size, j + 2*cp_size, ...``.
    ``is_inverse=True`` undoes this reorder.

    Args:
        tensor: array whose axis ``seq_dim`` is the sequence axis.
        cp_size: number of stripes (context-parallel size).
        seq_dim: index of the sequence axis. Negative values count from the end.
        is_inverse: if True, restore the original (unstriped) order.

    Returns:
        An array with the same shape as ``tensor``.

    Raises:
        ValueError: if the sequence length is not a multiple of ``cp_size``.
    """
    origin_shape = tensor.shape
    # The shape slicing and swapaxes below need a non-negative axis index.
    if seq_dim < 0:
        seq_dim += len(origin_shape)

    if origin_shape[seq_dim] % cp_size != 0:
        raise ValueError(
            "Expected origin_shape[seq_dim] is multiple of cp_size but got"
            f" {origin_shape[seq_dim]=} and {cp_size=}"
        )

    stripe_len = origin_shape[seq_dim] // cp_size
    split = (cp_size, stripe_len) if is_inverse else (stripe_len, cp_size)
    new_shape = (*origin_shape[:seq_dim], *split, *origin_shape[seq_dim + 1 :])

    chunked_tensor = tensor.reshape(new_shape)
    reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1)
    return reordered_chunked_tensor.reshape(origin_shape)


1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
    """Utilities shared by the all-gather context-parallel (CP) attention primitives.

    K/V (and dK/dV in the backward pass) are exchanged over the ``config.cp_axis``
    mesh axis along tensor dimension 1.
    """

    mesh: jax.sharding.Mesh
    config: _FusedAttnConfig

    def check_supported(self):
        """Raise ValueError if ``config`` uses a feature this CP strategy does not implement."""
        header = "Context parallel fused attention"
        cfg = self.config

        allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
        if cfg.qkv_layout not in allowed_layouts:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {cfg.qkv_layout}"
            )

        if cfg.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {cfg.attn_bias_type}")

        allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
        if cfg.attn_mask_type not in allowed_masks:
            raise ValueError(
                f"{header} only supports masking types: "
                f" {','.join(map(str, allowed_masks))} got: {cfg.attn_mask_type}"
            )

        if cfg.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got:"
                f" {cfg.max_segments_per_seq}"
            )

        if cfg.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

    def get_adjusted_mask(self):
        """Mask type for a single CP step: CAUSAL becomes CAUSAL_BOTTOM_RIGHT, others unchanged."""
        mask = self.config.attn_mask_type
        return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK if mask == AttnMaskType.CAUSAL_MASK else mask

    def get_step_config(self) -> _FusedAttnConfig:
        """Build the _FusedAttnConfig for one per-step fused-attention call (adjusted mask)."""
        cfg = self.config
        return _FusedAttnConfig(
            attn_bias_type=cfg.attn_bias_type,
            attn_mask_type=self.get_adjusted_mask(),
            qkv_layout=cfg.qkv_layout,
            scaling_factor=cfg.scaling_factor,
            dropout_probability=cfg.dropout_probability,
            is_training=cfg.is_training,
            max_segments_per_seq=cfg.max_segments_per_seq,
            window_size=cfg.window_size,
            context_parallel_load_balanced=cfg.context_parallel_load_balanced,
            cp_axis=cfg.cp_axis,
        )

    def all_gather_kv(self, k, v):
        """All-gather k (and v for separate layouts) across CP ranks along dimension 1.

        For load-balanced configs, the gathered tensor is restored to contiguous
        sequence order with reorder_causal_dual_chunk_swap. Layouts that are neither
        kv-packed nor separate are returned untouched.
        """
        cfg = self.config

        def _gather(x):
            gathered = lax_paral_op(
                x, lax.all_gather, cfg.cp_axis, mesh=self.mesh, axis=1, tiled=True
            )
            if not cfg.context_parallel_load_balanced:
                return gathered
            cp_size = get_mesh_axis_size(cfg.cp_axis, self.mesh)
            return reorder_causal_dual_chunk_swap(gathered, cp_size, 1, to_contiguous=True)

        layout = cfg.qkv_layout
        if layout.is_kvpacked():
            return _gather(k), v
        if layout.is_separate():
            return _gather(k), _gather(v)
        return k, v

    def reduce_scatter_dkv(self, dk, dv):
        """Reduce-scatter (sum) dk (and dv for separate layouts) over CP ranks along dimension 1.

        For load-balanced configs, the gradients are first permuted back into the
        load-balanced order so that each rank receives its own chunks.
        """
        cfg = self.config

        def _scatter(x):
            if cfg.context_parallel_load_balanced:
                cp_size = get_mesh_axis_size(cfg.cp_axis, self.mesh)
                x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)
            return lax_paral_op(
                x,
                lax.psum_scatter,
                cfg.cp_axis,
                mesh=self.mesh,
                scatter_dimension=1,
                tiled=True,
            )

        layout = cfg.qkv_layout
        if layout.is_kvpacked():
            return _scatter(dk), dv
        if layout.is_separate():
            return _scatter(dk), _scatter(dv)
        return dk, dv

    def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
        """Returns sequence lengths of KV to use for each sub rank of the given cp_rank.

        Example: CP=4, MaxLen = 1024, Unbalanced
           cp_rank 0: [128, 256]
           cp_rank 1: [384, 512]
           cp_rank 2: [640, 768]
           cp_rank 3: [896, 1024]

        Example: CP=4, MaxLen = 1024, Balanced
           cp_rank 0: [128, 1024]
           cp_rank 1: [256, 896]
           cp_rank 2: [384, 768]
           cp_rank 3: [512, 640]
        """
        step = kv_seqlen_per_subrank
        if self.config.context_parallel_load_balanced:
            return [(cp_rank + 1) * step, kv_max_seqlen - cp_rank * step]
        first = (cp_rank * 2 + 1) * step
        return [first, first + step]

    def slice_kv(self, k, v, slice_seq_len):
        """Keep the first slice_seq_len positions of dimension 1 of k (and v when separate)."""

        def _head(x):
            return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)

        layout = self.config.qkv_layout
        if layout.is_kvpacked():
            return _head(k), v
        if layout.is_separate():
            return _head(k), _head(v)
        return k, v

    def pad_kv(self, dk, dv, pad_seq_len):
        """Zero-pad dimension 1 of dk (and dv when separate) by pad_seq_len at the end.

        kv-packed gradients are padded as rank-5 arrays; separate ones as rank-4 arrays.
        """

        def _pad(x, rank):
            widths = [[0, 0], [0, pad_seq_len]] + [[0, 0]] * (rank - 2)
            return jnp.pad(x, widths, "constant", constant_values=0.0)

        layout = self.config.qkv_layout
        if layout.is_kvpacked():
            return _pad(dk, 5), dv
        if layout.is_separate():
            return _pad(dk, 4), _pad(dv, 4)
        return dk, dv


class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Attention Forward with Context Parallelism Primitive

    This context parallel implementation uses all-gather to collect KV inputs from context
    parallel ranks. Each rank then runs two fused-attention steps, one for each half of its
    local query shard, against the (possibly sliced) gathered KV.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition rule for the all-gather CP forward pass.

        Defers to ``FusedAttnFwdPrimitive.partition`` when the CP mesh axis has size 1.
        Returns ``(mesh, impl, out_shardings, arg_shardings)``.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        # The seed operand (index 4) and the rng_state output share one sharding over all mesh axes.
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # cuDNN does not support right-aligned masking with dynamic sequence length padding.
            # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
            # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
            # meeting the expectation of the SPMD model.
            # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
            # mask/sequence length tensor to avoid this unrolled loop.
            def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                q_split = jnp.split(q, 2, axis=1)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == AttnMaskType.NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    # Floor division keeps the per-step sequence lengths integral and identical
                    # to what the all-gather CP backward primitive computes. True division
                    # would produce floating-point sequence lengths.
                    q_seqlen_for_step = q_seqlen // (cp_size * 2)
                    num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

                    output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        seed,
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )
                    results.append((output, softmax_aux, rng_state))

                # Stitch the two half-shard results back together along the sequence axis
                # (axis 1 for the output, axis 2 for softmax_aux).
                output = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2)
                rng_state = results[1][2]  # Use the final RNG state

                return output, softmax_aux, rng_state

            k_ag, v_ag = helper.all_gather_kv(k, v)

            functions = [
                partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed)
                for idx in range(cp_size)
            ]

            return lax.switch(cp_rank, functions)

        return mesh, impl, out_shardings, arg_shardings


# Register the all-gather CP forward primitive via the shared `register_primitive` helper.
register_primitive(FusedAttnCPWithAllGatherFwdPrimitive)


class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused attention backward primitive with all-gather context parallelism.

    K/V are all-gathered over the context-parallel axis. Each rank computes gradients
    for the two halves of its local query shard. The resulting dK/dV are then
    reduce-scattered back to the context-parallel ranks.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Partition rule for the all-gather CP backward pass.

        Defers to ``FusedAttnBwdPrimitive.partition`` when the CP mesh axis has size 1.
        Returns ``(mesh, impl, out_shardings, arg_shardings)``.
        """
        cp_enabled = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not cp_enabled or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not cp_enabled:
            # No real CP axis: the plain backward partition is enough.
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Reject configurations the all-gather CP strategy does not handle.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        del result_infos
        out_shardings = tuple(
            NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[idx])))
            for idx in range(4)
        )
        arg_shardings = tuple(info.sharding for info in arg_infos)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            subranks = cp_size * 2

            # As in FusedAttnCPWithAllGatherFwdPrimitive.partition: cuDNN cannot express the
            # right-aligned mask with dynamic padding, so one branch per CP rank is traced and
            # lax.switch selects the branch for this rank.
            def _cross_attn_bwd(
                idx,
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlen,
                kv_seqlen,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
            ):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // subranks
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                q_halves = jnp.split(q, 2, axis=1)
                out_halves = jnp.split(output, 2, axis=1)
                dout_halves = jnp.split(doutput, 2, axis=1)
                aux_halves = jnp.split(softmax_aux, 2, axis=2)

                kv_lens = helper.kv_seqlens_for_rank(idx, kv_max_seqlen, kv_seqlen_per_subrank)
                masked = config.attn_mask_type != AttnMaskType.NO_MASK

                partial_grads = []
                for half in range(2):
                    if masked:
                        k_step, v_step = helper.slice_kv(k, v, kv_lens[half])
                    else:
                        k_step, v_step = k, v  # full kv used for unmasked

                    num_kv_chunks = kv_max_seqlen // kv_lens[half]
                    dq_h, dk_h, dv_h, dbias_h = FusedAttnBwdPrimitive.impl(
                        q_halves[half],
                        k_step,
                        v_step,
                        bias,
                        aux_halves[half],
                        rng_state,
                        out_halves[half],
                        dout_halves[half],
                        q_seqlen // subranks,
                        (kv_seqlen // subranks) * num_kv_chunks,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )

                    if masked:
                        # Zero-pad the sliced dk/dv back to the full gathered length so the
                        # reduce-scatter sees the unsliced shape.
                        dk_h, dv_h = helper.pad_kv(dk_h, dv_h, kv_max_seqlen - kv_lens[half])

                    partial_grads.append((dq_h, dk_h, dv_h, dbias_h))

                first, second = partial_grads
                dq_full = jnp.concatenate((first[0], second[0]), axis=1)
                return dq_full, first[1] + second[1], first[2] + second[2], second[3]

            k_all, v_all = helper.all_gather_kv(k, v)

            branches = [
                partial(
                    _cross_attn_bwd,
                    idx,
                    q,
                    k_all,
                    v_all,
                    bias,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_seqlen,
                    kv_seqlen,
                    _q_segment_ids,
                    _kv_segment_ids,
                    _q_segment_pos,
                    _kv_segment_pos,
                )
                for idx in range(cp_size)
            ]

            dq, dk_full, dv_full, dbias = lax.switch(cp_rank, branches)
            dk, dv = helper.reduce_scatter_dkv(dk_full, dv_full)
            return dq, dk, dv, dbias

        return mesh, impl, out_shardings, arg_shardings


# Register the all-gather CP backward primitive via the shared `register_primitive` helper.
register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)


1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
@dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper:
    """Helper class to assist with running the P2P ring strategy for CP attention.

    Holds the device mesh and the global fused-attention config, and provides the pieces the
    ring-attention primitives share: argument validation, the per-step config, KV
    stacking/unstacking, the ring permutation and the online-softmax output correction.
    """

    mesh: jax.sharding.Mesh
    config: _FusedAttnConfig

    @staticmethod
    def use_scanloop():
        """Returns true if the implementation will use a scan loop for iteration.

        Read from the ``NVTE_FUSED_RING_ATTENTION_USE_SCAN`` environment variable, which
        defaults to ``"1"``. The value must parse as an integer; any non-zero value enables
        the scan loop.
        """
        use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
        return use_scan

    def check_supported(self):
        """Checks if the context parallel implementation is supported by the given arguments.

        Raises:
            ValueError: if the QKV layout, bias type, mask type, ``max_segments_per_seq`` or
                dropout probability is not supported by fused ring attention.

        Only warns (does not raise) when the scan loop has been disabled.
        """
        header = "Context parallel fused ring attention"

        if self.config.qkv_layout.is_thd():
            allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
        else:
            allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]

        if self.config.qkv_layout not in allowed_layouts:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
            )

        if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

        if self.config.qkv_layout.is_thd():
            allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK]
        else:
            allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]

        if self.config.attn_mask_type not in allowed_masks:
            # Single space after the colon, matching the layout error message.
            raise ValueError(
                f"{header} only supports masking types:"
                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
            )

        # Multiple segments per sequence are only accepted for THD layouts.
        if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got:"
                f" {self.config.max_segments_per_seq}"
            )

        if self.config.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

        # We want to encourage use of scan loop to minimize unrolling and ensure more
        # predictable scheduling from XLA. The unrolled flavor will be supported but
        # not the preferred implementation.
        if not self.use_scanloop():
            warnings.warn(
                "Scan loop is disabled for fused ring attention. To enable set"
                " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
            )

    def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention.

        Every field is copied from the global config except the mask type, which the caller
        supplies, and the layout, which is fixed to ``QKVLayout.BSHD_BS2HD`` because the ring
        primitives pass K and V to each step stacked into one tensor (see ``stack_kv``).
        """
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=QKVLayout.BSHD_BS2HD,
            scaling_factor=self.config.scaling_factor,
            dropout_probability=self.config.dropout_probability,
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
            window_size=self.config.window_size,
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
        )

    def stack_kv(self, k, v):
        """Stacks k and v tensors if not stacked.

        Returns ``k`` unchanged for KV-packed layouts, ``jnp.stack([k, v], axis=2)`` for
        separate layouts, and an empty placeholder array for any other layout.
        """
        _not_used = jnp.zeros(0, dtype=k.dtype)
        if self.config.qkv_layout.is_kvpacked():
            return k
        if self.config.qkv_layout.is_separate():
            return jnp.stack([k, v], axis=2)
        return _not_used

    def unstack_kv(self, kv):
        """Un-stacks k and v tensors if not stacked.

        Inverse of ``stack_kv``. For KV-packed layouts returns ``(kv, placeholder)``; for
        separate layouts splits axis 2 into ``(k, v)``; otherwise returns two placeholders.
        """
        _not_used = jnp.zeros(0, dtype=kv.dtype)
        if self.config.qkv_layout.is_kvpacked():
            return kv, _not_used
        if self.config.qkv_layout.is_separate():
            return jnp.unstack(kv, axis=2)
        return _not_used, _not_used  # fall through

    def permute_kv(self, kv, cp_perm):
        """Permutes kv around the ring as described by cp_perm.

        ``cp_perm`` is the ``(source, destination)`` pair list forwarded to ``lax.ppermute``
        along the configured context-parallel mesh axis.
        """
        return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)

    @staticmethod
    def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
        """
        Corrects the output and softmax_aux tensor after each iteration of ring attention.

        See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
        derivation of this equation.

        The merged output is a convex combination of ``output`` and ``partial_output``,
        weighted by ``sigmoid(partial_softmax_aux - softmax_aux)``. The merged softmax_aux
        equals ``logaddexp(softmax_aux, partial_softmax_aux)``, written in a numerically stable
        form. The ``transpose(0, 2, 1, 3)`` assumes softmax_aux has its head and sequence axes
        swapped relative to the output (e.g. (b, h, s, 1) vs. (b, s, h, d)). TODO confirm
        for all callers.
        """
        new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
            0, 2, 1, 3
        ) * (output - partial_output)
        new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
        return new_out, new_aux

    def adjust_seqlen(self, seqlen, max_seqlen, idx):
        """Adjust the sequence length per step.

        Returns ``seqlen - max_seqlen * idx`` clamped to ``[0, max_seqlen]``. This is the
        number of tokens of a length-``seqlen`` prefix that fall into block ``idx`` when the
        sequence is split into consecutive blocks of ``max_seqlen`` tokens.
        """
        return jnp.clip(seqlen - max_seqlen * idx, 0, max_seqlen)


class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Ring Attention Forward Primitive
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Builds the ring-attention forward implementation for ``mesh``.

        When the context-parallel (CP) axis has a single device this defers to
        ``FusedAttnFwdPrimitive.partition``. Otherwise it returns
        ``(mesh, ring_attn_fwd_impl, out_shardings, arg_shardings)``. The impl rotates the
        local KV shard around the CP ring and runs the fused attention kernel on each
        (local Q, visiting KV) pair. It merges the partial outputs with the online-softmax
        correction.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Argument 4 is the seed; it shares the rng_state sharding over all mesh axes.
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def ring_attn_fwd_impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            # Placeholder passed in the separate-V slot: KV is always handed over stacked.
            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            batch, q_max_seqlen, head, _ = q.shape
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            # Rank i sends to rank i + 1, so at step idx this rank holds the KV shard that
            # originated idx ranks earlier on the ring.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            # Accumulate in float32; the output is cast back to q.dtype at the end.
            output = jnp.zeros(q.shape, dtype=jnp.float32)
            softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

            def scan_kv_block(idx, carry):
                kv, output, softmax_aux = carry

                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)

                def attn_step(q_block, kv_block, q_seqlen_step, kv_seqlen_step, attn_mask_type):
                    """Runs the single-device fused attention kernel on one block pair."""
                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                        q_block,
                        kv_block,
                        _not_used,
                        bias,
                        seed,
                        q_seqlen_step,
                        kv_seqlen_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(attn_mask_type),
                    )
                    return output_per_step, softmax_aux_per_step

                def mask_compute(attn_mask_type):
                    # Full local Q against the full visiting KV block.
                    return attn_step(
                        q,
                        kv,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        attn_mask_type,
                    )

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # Full local Q against only the first half of the visiting KV block.
                    kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
                    return attn_step(
                        q,
                        kv_part,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2,
                        AttnMaskType.NO_MASK,
                    )

                def half_q_no_mask_compute():
                    # Only the second half of local Q attends to the full visiting KV block.
                    # The first half gets a zero output and -inf softmax stats. The correction
                    # step then leaves its accumulated output/stats unchanged.
                    q_half = q_max_seqlen // 2
                    q_part = lax.slice_in_dim(q, q_half, q_max_seqlen, axis=1)
                    output_per_step, softmax_aux_per_step = attn_step(
                        q_part,
                        kv,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2,
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        AttnMaskType.NO_MASK,
                    )
                    output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
                    softmax_aux_per_step = jnp.concat(
                        [
                            jnp.full_like(softmax_aux_per_step, -jnp.inf),
                            softmax_aux_per_step,
                        ],
                        axis=2,
                    )
                    return output_per_step, softmax_aux_per_step

                def skip_compute():
                    # Used for causal masking without load balancing when idx > cp_rank: the
                    # visiting KV contributes nothing to any local query.
                    output_per_step = jnp.zeros_like(q)
                    softmax_aux_per_step = jnp.full(
                        (batch, head, q.shape[1], 1), -jnp.inf, dtype=jnp.float32
                    )
                    return output_per_step, softmax_aux_per_step

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # This is for nested jax.lax.cond. With load balancing each local shard is
                    # assumed to hold an early and a mirrored late chunk (TODO confirm against
                    # the reordering code). A KV shard from an earlier rank (idx <= cp_rank)
                    # is then only seen through its first half. A KV shard from a later rank
                    # is only seen by the second half of local Q.
                    def jax_cond_wrap():
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    # Step 0 pairs local Q with the local KV shard, so the causal mask applies.
                    output_per_step, softmax_aux_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    output_per_step, softmax_aux_per_step = no_mask_compute()

                def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    # pylint: disable=unused-argument
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    return helper.correct_output_and_softmax_aux(
                        output, softmax_aux, output_per_step, softmax_aux_per_step
                    )

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    (idx == 0),
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, output, softmax_aux)

            carry = (kv, output, softmax_aux)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (kv, output, softmax_aux) = carry

            output = output.astype(q.dtype)
            return output, softmax_aux, rng_state

        return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnFwdPrimitive)


class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Ring Attention Backward Primitive
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Builds the ring-attention backward implementation for ``mesh``.

        When the context-parallel (CP) axis has a single device this defers to
        ``FusedAttnBwdPrimitive.partition``. Otherwise it returns
        ``(mesh, ring_attn_bwd_impl, out_shardings, arg_shardings)``, with dq/dk/dv/dbias
        sharded like q/k/v/bias. The impl rotates KV and a running dK/dV accumulator around
        the CP ring, computing per-step gradients with the fused backward kernel.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def ring_attn_bwd_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            # Placeholder passed in the separate-V slot: KV is always handed over stacked.
            _not_used = jnp.zeros(0, dtype=output.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            q_max_seqlen = q.shape[1]
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            # Rank i sends to rank i + 1, so at step idx this rank holds the KV shard that
            # originated idx ranks earlier on the ring.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
            dbias = jnp.zeros_like(bias)

            def scan_kv_block(idx, carry):

                kv, dq, dk_dv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.
                kv_dk_dv = jnp.stack([kv, dk_dv])
                kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)

                def attn_step(
                    q_block,
                    kv_block,
                    softmax_aux_block,
                    output_block,
                    doutput_block,
                    q_seqlen_step,
                    kv_seqlen_step,
                    attn_mask_type,
                ):
                    """Runs the single-device fused attention backward on one block pair."""
                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
                        q_block,
                        kv_block,
                        _not_used,
                        bias,
                        softmax_aux_block,
                        rng_state,
                        output_block,
                        doutput_block,
                        q_seqlen_step,
                        kv_seqlen_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(attn_mask_type),
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def mask_compute(attn_mask_type):
                    # Full local Q against the full visiting KV block.
                    return attn_step(
                        q,
                        kv,
                        softmax_aux,
                        output,
                        doutput,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        attn_mask_type,
                    )

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # Full local Q against only the first half of the visiting KV block.
                    kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
                    dq_per_step, dk_dv_per_step, dbias_per_step = attn_step(
                        q,
                        kv_part,
                        softmax_aux,
                        output,
                        doutput,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx),
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2,
                        AttnMaskType.NO_MASK,
                    )
                    # The second half of the visiting KV receives no gradient this step.
                    dk_dv_per_step = jnp.concat(
                        [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def half_q_no_mask_compute():
                    # Only the second half of local Q attends to the full visiting KV block.
                    q_half = q_max_seqlen // 2
                    q_part = lax.slice_in_dim(q, q_half, q_max_seqlen, axis=1)
                    output_part = lax.slice_in_dim(output, q_half, q_max_seqlen, axis=1)
                    doutput_part = lax.slice_in_dim(doutput, q_half, q_max_seqlen, axis=1)
                    # softmax_aux keeps the sequence dimension on axis 2.
                    softmax_aux_part = lax.slice_in_dim(softmax_aux, q_half, q_max_seqlen, axis=2)
                    dq_per_step, dk_dv_per_step, dbias_per_step = attn_step(
                        q_part,
                        kv,
                        softmax_aux_part,
                        output_part,
                        doutput_part,
                        helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2,
                        helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx),
                        AttnMaskType.NO_MASK,
                    )
                    # The first half of local Q receives no gradient this step.
                    dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def skip_compute():
                    # The visiting KV contributes nothing: all per-step gradients are zero.
                    return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # This is for nested jax.lax.cond. The branch choice mirrors the forward
                    # ring-attention primitive.
                    def jax_cond_wrap():
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    # Step 0 pairs local Q with the local KV shard, so the causal mask applies.
                    dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute()

                # The dK/dV accumulator received from rank r-1 was started one step earlier.
                # It therefore belongs to the same KV shard this rank processed in this step,
                # so this step's gradient is added to it.
                kv_next, dk_dv = jnp.unstack(kv_dk_dv)
                dq = dq + dq_per_step
                dk_dv = dk_dv + dk_dv_per_step
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step

                return (kv_next, dq, dk_dv, dbias)

            carry = (kv, dq, dk_dv, dbias)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (kv, dq, dk_dv, dbias) = carry

            # Final permute to put gradients back to their final resting place: after the loop
            # each accumulator is one hop short of the rank that owns its KV shard.
            dk_dv = helper.permute_kv(dk_dv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dk_dv)
            return dq, dk, dv, global_dbias

        return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnBwdPrimitive)


Reese Wang's avatar
Reese Wang committed
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Striped Ring Attention Forward Primitive
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Builds the striped ring-attention forward implementation (THD layouts).

        With a single device on the context-parallel axis, the plain
        ``FusedAttnFwdPrimitive.partition`` is used instead. Otherwise the returned impl
        rotates the stacked KV block, together with its segment ids and positions, around the
        ring. It merges the per-step outputs with the online-softmax correction.

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)``.
        """
        cp_enabled = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not cp_enabled or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not cp_enabled:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        # The seed input (argument 4) and the rng_state output are sharded over all mesh axes.
        rng_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        arg_shardings = [info.sharding for info in arg_infos]
        arg_shardings[4] = rng_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (result_infos[0].sharding, result_infos[1].sharding, rng_sharding)

        def fwd_impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            placeholder = jnp.zeros(0, dtype=v.dtype)

            # Stack K/V so one ppermute moves both; each step then runs in KV-packed form.
            kv = helper.stack_kv(k, v)
            subblock_config = (
                config
                if config.qkv_layout.is_qkvpacked()
                else replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            )

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            ring = [(src, (src + 1) % cp_size) for src in range(cp_size)]

            batch, q_max_seqlen, head, _ = q.shape
            stats_shape = (batch, q_max_seqlen, head, 1)
            out_acc = jnp.zeros(q.shape, dtype=jnp.float32)
            lse_acc = jnp.zeros(stats_shape, dtype=jnp.float32)

            # RNG state is unused (dropout is unsupported) but must have the per-shard shape.
            rng_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_shape).astype(result_infos[2].dtype)

            def first_step(_out, _lse, step_out, step_lse):
                # Nothing to merge yet: take the step result, promoted to float32.
                return step_out.astype(jnp.float32), step_lse

            def merge_step(out, lse, step_out, step_lse):
                # Online-softmax merge; stats share the (b, s, h, 1) layout with the output.
                merged_out = out - jax.nn.sigmoid(step_lse - lse) * (out - step_out)
                merged_lse = lse - jax.nn.log_sigmoid(lse - step_lse)
                return merged_out, merged_lse

            def ring_step(idx, carry):
                kv_blk, kv_seg_ids, kv_seg_pos, out, lse = carry

                # TODO(rewang): To check whether we need special handle for the last idx
                # Launch the sends for the next step first so they overlap with compute.
                kv_blk_next = helper.permute_kv(kv_blk, ring)
                kv_seg_ids_next = helper.permute_kv(kv_seg_ids, ring)
                kv_seg_pos_next = helper.permute_kv(kv_seg_pos, ring)

                step_out, step_lse, _ = FusedAttnFwdPrimitive.impl(
                    q,
                    kv_blk,
                    placeholder,
                    bias,
                    seed,
                    q_seqlen,
                    kv_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    q_segment_ids,
                    kv_seg_ids,
                    q_segment_pos,
                    kv_seg_pos,
                    subblock_config,
                )
                # TODO(rewang): THD softmax_aux layout is actually [B, S, H]
                step_lse = step_lse.reshape(stats_shape)

                out, lse = lax.cond(idx == 0, first_step, merge_step, out, lse, step_out, step_lse)
                return (kv_blk_next, kv_seg_ids_next, kv_seg_pos_next, out, lse)

            carry = (kv, kv_segment_ids, kv_segment_pos, out_acc, lse_acc)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, ring_step, carry)
            else:
                for step in range(cp_size):
                    carry = ring_step(step, carry)
            *_, out_acc, lse_acc = carry

            # Restore the kernel's raw softmax_aux buffer shape (a reshape, not a transpose).
            lse_acc = lse_acc.reshape((batch, head, q_max_seqlen, 1))

            return out_acc.astype(q.dtype), lse_acc, rng_state

        return mesh, fwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnStripedFwdPrimitive)


class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Striped Ring Attention Backward Primitive

    Backward counterpart of FusedRingAttnStripedFwdPrimitive. When the context-parallel
    mesh axis spans more than one device, every rank keeps its local query block while the
    packed key/value block travels around the ring together with its running gradient, so
    each rank adds its contribution to every KV block's gradient before it is sent home.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Return ``(mesh, impl, out_shardings, arg_shardings)`` for the backward pass.

        Without context parallelism this defers to ``FusedAttnBwdPrimitive.partition``.
        """
        if get_mesh_axis_size(config.cp_axis, mesh) <= 1:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
        assert (
            config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"

        in_shardings = tuple(info.sharding for info in arg_infos)
        # Gradients mirror their primals: dq, dk, dv, dbias share the q, k, v, bias sharding.
        out_shardings = in_shardings[:4]

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def bwd_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            # Per-shard ring loop; returns (dq, dk, dv, dbias) for this rank.
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            # Custom calls cannot take None, so the unused V slot receives an empty array.
            placeholder = jnp.zeros(0, dtype=output.dtype)

            # Fuse K and V into one tensor so each ring hop is a single collective
            # (better permute scheduling); XLA should eventually do this by itself.
            kv_blk = helper.stack_kv(k, v)
            if config.qkv_layout.is_qkvpacked():
                step_config = config
            else:
                step_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())

            ring = get_mesh_axis_size(config.cp_axis, mesh)
            # Every rank sends to its successor and receives from its predecessor.
            ring_perm = [(src, (src + 1) % ring) for src in range(ring)]

            def ring_step(_hop, state):
                cur_kv, cur_seg_ids, cur_seg_pos, dq_acc, dkv_acc, dbias_acc = state

                # Issue the transfers for the next hop first so they can overlap with the
                # attention kernel below; kv and its running gradient share one permute.
                in_flight = helper.permute_kv(jnp.stack([cur_kv, dkv_acc]), ring_perm)
                next_seg_ids = helper.permute_kv(cur_seg_ids, ring_perm)
                next_seg_pos = helper.permute_kv(cur_seg_pos, ring_perm)

                step_dq, step_dkv, _, step_dbias = FusedAttnBwdPrimitive.impl(
                    q,
                    cur_kv,
                    placeholder,
                    bias,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_seqlen,
                    kv_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    q_segment_ids,
                    cur_seg_ids,
                    q_segment_pos,
                    cur_seg_pos,
                    config=step_config,
                )

                # The received gradient lags the received KV block by one hop, i.e. it is
                # the running gradient of cur_kv, so this rank's share is added to it.
                next_kv, recv_dkv = jnp.unstack(in_flight)
                dq_acc = dq_acc + step_dq
                dkv_acc = recv_dkv + step_dkv
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias_acc = dbias_acc + step_dbias

                return (next_kv, next_seg_ids, next_seg_pos, dq_acc, dkv_acc, dbias_acc)

            state = (
                kv_blk,
                kv_segment_ids,
                kv_segment_pos,
                jnp.zeros_like(q),
                jnp.zeros_like(kv_blk),
                jnp.zeros_like(bias),
            )
            if helper.use_scanloop():
                state = lax.fori_loop(0, ring, ring_step, state)
            else:
                for hop in range(ring):
                    state = ring_step(hop, state)
            *_, dq, dkv, dbias = state

            # After the last hop each rank holds the finished gradient of its successor's
            # KV block; one more hop delivers it to the owning rank.
            dkv = helper.permute_kv(dkv, ring_perm)

            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                # Sum the bias gradient across the dp/fsdp mesh axes.
                dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dkv)
            return dq, dk, dv, dbias

        return mesh, bwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnStripedBwdPrimitive)


2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
def _maybe_context_parallel_axis(cp_axis: str):
    if not cp_axis:
        gmr = global_mesh_resource()
        if gmr is not None:
            cp_axis = gmr.cp_resource
        else:
            cp_axis = ""
    return cp_axis


2421
2422
2423
def fused_attn_fwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
2424
    sequence_descriptor: SequenceDescriptor,
2425
    seed: Optional[jnp.ndarray],
Reese Wang's avatar
Reese Wang committed
2426
2427
2428
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
2429
2430
2431
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
2432
    max_segments_per_seq: int,
2433
    window_size: Optional[Tuple[int, int]] = None,
2434
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
2435
2436
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
2437
) -> jnp.ndarray:
2438
    """
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
    Perform the forward pass of with cuDNN fused attention implementations.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
Reese Wang's avatar
Reese Wang committed
2459
2460
2461
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
2462
2463
2464
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
2465
2466
2467
2468
2469
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
2470
2471
2472
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
2473
2474
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
2475
    """
2476
2477
2478
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
2479

Reese Wang's avatar
Reese Wang committed
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
2497
        assert bias is None
2498
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
2499

2500
    fused_config = _FusedAttnConfig(
2501
2502
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
2503
        qkv_layout=qkv_layout,
2504
2505
2506
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
2507
        max_segments_per_seq=max_segments_per_seq,
2508
        window_size=(-1, -1) if window_size is None else window_size,
2509
2510
2511
2512
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
    )

2513
    primitive = None
2514
2515
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
2516
            primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
2517
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
2518
2519
2520
2521
2522
            # We must use stripe attention for THD-RING
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnFwdPrimitive.outer_primitive
2523

2524
2525
    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
    return primitive.bind(
2526
2527
2528
        *qkv_for_primitive,
        bias,
        seed,
2529
        *seq_desc_flatten,
2530
        config=fused_config,
2531
2532
2533
    )


2534
2535
2536
def fused_attn_bwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
2537
2538
2539
2540
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
2541
    sequence_descriptor: SequenceDescriptor,
Reese Wang's avatar
Reese Wang committed
2542
2543
2544
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
2545
2546
2547
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
2548
    max_segments_per_seq: int,
2549
    window_size: Optional[Tuple[int, int]] = None,
2550
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
2551
2552
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
2553
):
2554
    """
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
    Perform the backward pass of the cuDNN fused attention implementations.

    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
        used in the forward pass. It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
        rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
        output (jnp.ndarray): The output tensor from the forward pass.
        doutput (jnp.ndarray): The gradient with respect to the output.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
Reese Wang's avatar
Reese Wang committed
2576
2577
2578
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
2579
2580
2581
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
2582
2583
2584
2585
2586
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size .
2587
2588
2589
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
2590
2591
2592
2593
2594
    Returns:
        Tuple[jnp.ndarray, ...], jnp.ndarray:
        - The first tuple contains the gradients with respect to the input `qkv` tensors in the
          same format as the input `qkv`.
        - The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
2595
    """
2596
2597
2598
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)

Reese Wang's avatar
Reese Wang committed
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
2616
        assert bias is None
2617
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
2618

2619
2620
2621
2622
2623
2624
2625
2626
    fused_config = _FusedAttnConfig(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
2627
        window_size=(-1, -1) if window_size is None else window_size,
2628
2629
2630
2631
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
    )

2632
    primitive = None
2633
2634
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
2635
            primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
2636
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
2637
2638
2639
2640
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnBwdPrimitive.outer_primitive
2641
2642
2643

    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
    *qkv_grads, bias_grad = primitive.bind(
2644
        *qkv_for_primitive,
2645
2646
2647
2648
2649
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
2650
        *seq_desc_flatten,
2651
        config=fused_config,
2652
    )
2653
    return tuple(qkv_grads[: len(qkv)]), bias_grad