attention.py 102 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
Reese Wang's avatar
Reese Wang committed
5
from dataclasses import dataclass, replace
6
from functools import partial, reduce
7
import operator
8
import os
9
from typing import Optional, Tuple
10
11
import warnings

12
import jax
13
import jax.numpy as jnp
14
from jax import dtypes, lax
15
16
17
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
18
from jax import ffi
19

Reese Wang's avatar
Reese Wang committed
20
21
22
23
24
25
26
27
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    CPStrategy,
    SequenceDescriptor,
)
28

29
from transformer_engine import transformer_engine_jax
Reese Wang's avatar
Reese Wang committed
30
31
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend

32
33
34
35
36
37
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
38
    get_padded_spec,
39
    get_cudnn_version,
40
    is_ffi_enabled,
41
    get_xla_flag,
42
43
)
from ..sharding import (
44
45
    global_mesh_resource,
    lax_paral_op,
46
    all_reduce_sum_along_dp_fsdp,
47
48
    get_mesh_axis_size,
    get_mesh_axis_rank,
49
50
51
52
53
    get_all_mesh_axes,
    num_of_devices,
)


54
55
56
57
58
# Public API of this module: the backend-availability helper plus the forward
# and backward fused-attention entry points.
__all__ = [
    "FusedAttnHelper",
    "fused_attn_fwd",
    "fused_attn_bwd",
]
59
60


61
62
63
64
65
66
67
68
69
70
71
@partial(
    jax.tree_util.register_dataclass,
    data_fields=[],
    meta_fields=[
        "attn_bias_type",
        "attn_mask_type",
        "qkv_layout",
        "scaling_factor",
        "dropout_probability",
        "is_training",
        "max_segments_per_seq",
        "window_size",
        "context_parallel_load_balanced",
        "cp_axis",
    ],
)
@dataclass(frozen=True)
class _FusedAttnConfig:
    """
    Passes static configuration properties of fused attention.

    Every field is registered as a pytree *meta* field (there are no data
    fields), so an instance is treated by JAX as static structure rather than
    as traced array data.
    """

    # Kind of attention bias; compared against AttnBiasType.NO_BIAS to decide
    # whether the bias operand carries a meaningful shape.
    attn_bias_type: AttnBiasType
    # Kind of attention mask forwarded to the backend.
    attn_mask_type: AttnMaskType
    # Memory layout of q/k/v (qkv-packed, kv-packed or separate tensors).
    qkv_layout: QKVLayout
    # Scaling factor forwarded to the kernel as `scaling_factor`.
    scaling_factor: float
    # Dropout probability forwarded to the kernel.
    dropout_probability: float
    # Training flag forwarded to the kernel as `is_training`.
    is_training: bool
    # Maximum number of packed segments per sequence; also sizes the softmax
    # auxiliary buffer when cuDNN is older than 9.6.
    max_segments_per_seq: int
    # Sliding-window sizes, forwarded as (window_size_left, window_size_right).
    window_size: Tuple[int, int]
    # Context-parallel settings. They are not read in this part of the file;
    # presumably consumed by the context-parallel primitives -- TODO confirm.
    context_parallel_load_balanced: bool
    cp_axis: str


95
96
97
98
99
100
101
102
@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend

    Holds the problem description (dtypes, layout, bias/mask kinds, dropout,
    head counts, sequence lengths, head dimension and sliding window) that is
    handed to the native library to select a fused attention backend.
    """

    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: QKVLayout
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
    q_max_seqlen: int
    kv_max_seqlen: int
    head_dim: int
    window_size: Tuple[int, int]

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        # The native query takes the raw backend enums, hence `.value` on the
        # Python-side enum wrappers.
        return transformer_engine_jax.get_fused_attn_backend(
            jax_dtype_to_te_dtype(self.q_dtype),
            jax_dtype_to_te_dtype(self.kv_dtype),
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_probability,
            self.q_num_heads,
            self.kv_num_heads,
            self.q_max_seqlen,
            self.kv_max_seqlen,
            self.head_dim,
            self.window_size[0],
            self.window_size[1],
        )

    @staticmethod
    def is_non_deterministic_allowed():
        """
        Check if non-deterministic kernels are allowed

        Reads the NVTE_ALLOW_NONDETERMINISTIC_ALGO environment variable as an
        integer; it defaults to "1" (allowed) when unset.
        """
        return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

    @staticmethod
    def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
        """
        Parse qkv aval

        Decomposes the q/k/v shapes according to `qkv_layout` and checks that
        they agree with each other.

        Returns:
            (batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads,
            num_gqa_groups, head_dim), where batch_shape is the list of
            leading batch dimensions.

        Raises:
            NotImplementedError: if the layout uses the SBHD format.
            ValueError: if the layout is neither qkv-packed, kv-packed nor
                separate.
        """
        if qkv_layout.get_qkv_format() == QKVFormat.SBHD:
            raise NotImplementedError(f"SBHD format is not supported, got {qkv_layout=}")

        if qkv_layout.is_qkvpacked():
            # q_aval: (...batch, seqlen, 3, heads, head_dim); k/v live inside q
            *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
            kv_batch_shape = q_batch_shape
            kv_max_seqlen = q_max_seqlen
            num_gqa_groups = attn_heads
            kv_head_dim = q_head_dim
            assert nqkv == 3
        elif qkv_layout.is_kvpacked():
            # q_aval: (...batch, q_seqlen, heads, head_dim)
            # k_aval: (...batch, kv_seqlen, 2, gqa_groups, head_dim)
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
            assert nkv == 2
        elif qkv_layout.is_separate():
            # q_aval: (...batch, q_seqlen, heads, head_dim)
            # k_aval/v_aval: (...batch, kv_seqlen, gqa_groups, head_dim)
            *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
            *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
            assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}"
        else:
            raise ValueError(f"Unexpected {qkv_layout=}")

        assert q_batch_shape == kv_batch_shape, f"{q_batch_shape=} {kv_batch_shape=}"
        assert q_head_dim == kv_head_dim, f"{q_head_dim=} {kv_head_dim=}"
        assert (
            q_aval.dtype == k_aval.dtype == v_aval.dtype
        ), f"{q_aval.dtype=} {k_aval.dtype=} {v_aval.dtype=}"

        return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.
    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    However, JAX doesn't enable 64 bits by default,
    so we have to emulate seed as two 32 bits array.
    The offset calculation is maintained in the backend.
    """

    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64
    seed_size: int = 2
    # (seed, offset) with internal dtype int64
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.

        Args:
            seed: seed array, or None when no dropout is applied.
            dropout_probability: dropout probability of the attention call.
            is_training: whether the call runs in training mode.

        Returns:
            The seed as an array of dtype `rng_state_dtype` holding at least
            `seed_size` elements.

        Raises:
            AssertionError: if `seed` is None while dropout is enabled, or if
                the seed holds fewer than `seed_size` elements.
        """
        # Jax can't bind None, create a dummy tensor for None
        if seed is None:
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            seed = jnp.zeros(2, dtype=self.rng_state_dtype)
            # Repeat the dummy seed by the device count reported by num_of_devices()
            seed = jnp.repeat(seed, num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                "Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning."
            )
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size

        return seed


def generate_cu_seqlen(actual_seqlen):
    """
    Generating cumsum seqlen for a batch

    Negative entries are treated as zero-length sequences. The result starts
    with a leading 0, so it has one more element than `actual_seqlen` along
    the summed axis.
    """
    # Negative lengths are clamped to 0 so they do not shift the later offsets
    actual_seqlen = jnp.where(actual_seqlen < 0, 0, actual_seqlen)
    cu_seqlen = jnp.cumulative_sum(actual_seqlen, include_initial=True)
    return cu_seqlen


class FusedAttnFwdPrimitive(BasePrimitive):
    """
    Fused Attention Forward Primitive
    """

    name = "te_fused_attn_forward"
    multiple_results = True
    # Index of `config` in the positional argument list of `impl`
    impl_static_args = (13,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        seed_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd abstract

        Returns the abstract values of (output, softmax_aux, rng_state,
        workspace). The workspace is only needed by the inner primitive.
        """
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert (
            q_dtype == k_dtype == v_dtype == bias_dtype
        ), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}"
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
            f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
            f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
        )

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
        )

        output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
        out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)

        # backend determines the softmax buffer shape/dtype
        backend = FusedAttnHelper(
            q_dtype,
            k_dtype,
            config.qkv_layout,
            config.attn_bias_type,
            config.attn_mask_type,
            config.dropout_probability,
            attn_heads,
            num_gqa_groups,
            q_max_seqlen,
            kv_max_seqlen,
            head_dim,
            config.window_size,
        ).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
            softmax_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            # cuDNN 9.6 reduces the required softmax shape
            if get_cudnn_version() >= (9, 6, 0):
                softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
            else:
                softmax_shape = (
                    *batch_shape,
                    attn_heads,
                    q_max_seqlen,
                    config.max_segments_per_seq,
                )
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f"Unsupported {backend=}")
        softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)

        # JAX does not enable 64-bit int by default, so the rng state is allocated as
        # `rng_state_size` 32-bit unsigned ints per row to get the buffer size the
        # C++ kernel needs.
        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
        # prepare for the active fused-attn backend
        input_batch = reduce(operator.mul, batch_shape)
        wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
        )
        wkspace_aval = q_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract

        Same as `abstract`, with the trailing workspace aval dropped.
        """
        out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract(
            *args, **kwargs
        )
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        seed,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd lowering rules

        Uses the FFI lowering when `is_ffi_enabled()` is true, otherwise falls
        back to the legacy custom call with a packed opaque descriptor.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
        )

        input_batch = reduce(operator.mul, batch_shape)

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        if is_ffi_enabled():
            name = "te_fused_attn_forward_ffi"
            out = ffi.ffi_lowering(name)(
                ctx,
                q,
                k,
                v,
                bias,
                seed,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
                # The segment ids/positions are passed through because ffi_lowering
                # requires the operand count to match primitive.lowering.
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
                input_batch=input_batch,
                bias_batch=bias_batch,
                q_max_seqlen=q_max_seqlen,
                kv_max_seqlen=kv_max_seqlen,
                attn_heads=attn_heads,
                num_gqa_groups=num_gqa_groups,
                bias_heads=bias_heads,
                head_dim=head_dim,
                max_segments_per_seq=config.max_segments_per_seq,
                scaling_factor=float(config.scaling_factor),
                dropout_probability=float(config.dropout_probability),
                bias_type=int(config.attn_bias_type.value),
                mask_type=int(config.attn_mask_type.value),
                qkv_layout=int(config.qkv_layout.value),
                is_training=config.is_training,
                deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
                window_size_left=config.window_size[0],
                window_size_right=config.window_size[1],
            )
        else:
            # The legacy custom call does not take the segment ids/positions
            operands = [
                q,
                k,
                v,
                bias,
                seed,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
            ]
            # Materialize the shapes so they can be iterated more than once
            operand_shapes = [operand.type.shape for operand in operands]
            out_types = [
                ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
                for output in ctx.avals_out
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            # The workspace is the last output of the inner primitive (see `abstract`)
            wkspace_aval = ctx.avals_out[-1]

            # Pass the raw backend enums (`.value`), consistent with every other
            # native call in this file.
            opaque = transformer_engine_jax.pack_fused_attn_descriptor(
                input_batch,
                bias_batch,
                q_max_seqlen,
                kv_max_seqlen,
                attn_heads,
                num_gqa_groups,
                bias_heads,
                head_dim,
                config.max_segments_per_seq,
                wkspace_aval.size,
                config.scaling_factor,
                config.dropout_probability,
                config.attn_bias_type.value,
                config.attn_mask_type.value,
                config.qkv_layout.value,
                jax_dtype_to_te_dtype(q_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                config.is_training,
                not FusedAttnHelper.is_non_deterministic_allowed(),
                config.window_size[0],
                config.window_size[1],
            )

            out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(
        q,
        k,
        v,
        bias,
        seed,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config: _FusedAttnConfig,
    ):
        """
        Fused attention fwd implementation

        Derives sequence lengths and offsets from the sequence descriptor,
        compacts them for THD layouts, converts the lengths to cumulative
        lengths and binds the inner primitive.

        Returns (output, softmax_aux, rng_state); the workspace is dropped.
        """
        assert FusedAttnFwdPrimitive.inner_primitive is not None

        sequence_descriptor = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )

        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
                config.qkv_layout,
                config.window_size,
                config.max_segments_per_seq,
            )
        )

        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
                # Move the entries selected by `condition` to the front (in flattened
                # order) and pad the tail with `fill_value`, keeping the shape of `x`.
                x_shape = x.shape
                x = x.flatten()
                size = x.size
                # Unused index slots are set to `size`, which is out of bounds, so
                # jnp.take returns `fill_value` for them.
                indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
                y = jnp.take(x, indices, fill_value=fill_value)
                return jnp.reshape(y, x_shape)

            def convert_to_2d(offsets, batch, max_seqlen):
                # Shift the valid (>= 0) per-row offsets by row_index * max_seqlen;
                # negative placeholders are left unchanged.
                offsets_2d = jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )
                return offsets_2d

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
            kv_batch = q_batch = batch[0]

            # Gather valid q_seqlen, which is greater than 0
            # cuDNN version < 9.3.0:
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0
            # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            if get_cudnn_version() >= (9, 3, 0):
                fill_value = 0
            else:
                fill_value = -1

            q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
            kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid q_seq_offsets, which is greater and equal to 0
            # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # And set the unused position to max size (batch * max_seqlen)
            # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _fix_len_take(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _fix_len_take(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        """
        Fused attention fwd batching rule

        Binds the outer primitive on the batched arguments. The output and the
        softmax aux follow the batch dim of q; the rng state follows the batch
        dim of the seed.
        """
        check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, _, _, _, seed_bdim, *_ = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
        return (
            FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """
        Infer the (output, softmax_aux, rng_state) shardings from the sharding of q.

        Raises:
            ValueError: if the qkv layout is neither qkv-packed, kv-packed nor
                separate.
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        if config.qkv_layout.is_qkvpacked():
            # q_spec = (...batch, q_seqlen, 3, head, hidden)
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
            softmax_aux_sharding = NamedSharding(
                mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
            )
        elif config.qkv_layout.is_kvpacked():
            # q_spec = (...batch, q_seqlen, head, hidden)
            # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
            softmax_aux_sharding = NamedSharding(
                mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
            )
        elif config.qkv_layout.is_separate():
            # q_spec = (...batch, q_seqlen, head, hidden)
            # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
            out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
            softmax_aux_sharding = NamedSharding(
                mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
            )
        else:
            raise ValueError(f"Unsupported {config.qkv_layout=}")
        # The first dim of the rng state is sharded over all mesh axes
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """
        Partition rule: returns (mesh, impl, out_shardings, arg_shardings).

        The seed and the rng state are sharded over all mesh axes, and the
        last two operands reuse the shardings of the two operands before them.
        """
        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[4] = seed_sharding
        # Assuming arg_infos follows the operand order of `impl`, the segment
        # positions reuse the shardings of the matching segment ids.
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(FusedAttnFwdPrimitive.impl, config=config)
        return mesh, impl, out_shardings, arg_shardings


# Register the forward primitive through the shared helper imported from `.base`
# (presumably this sets up inner_primitive/outer_primitive -- confirm in `.base`).
register_primitive(FusedAttnFwdPrimitive)


class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive
    """
655

656
657
    name = "te_fused_attn_backward"
    multiple_results = True
658
    impl_static_args = (16,)
659
660
661
662
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        q_aval,
        k_aval,
        v_aval,
        bias_aval,
        softmax_aux_aval,
        rng_state_aval,
        output_aval,
        doutput_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
        _q_seq_offsets,
        _k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd abstract

        Returns the abstract values of (dq, dk, dv, dbias, workspace). Each
        gradient has the shape and canonicalized dtype of its matching input.
        """
        # These operands do not influence the output shapes/dtypes
        del softmax_aux_aval, rng_state_aval, output_aval

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
        )

        if config.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        deterministic = not FusedAttnHelper.is_non_deterministic_allowed()

        # Query the native library for the workspace shape/dtype the backward
        # kernel needs for this problem size.
        input_batch = reduce(operator.mul, batch_shape)
        wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
            deterministic,
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
        )

        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(
            shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
        )

        return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract

        Same as `abstract`, with the trailing workspace aval dropped.
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        return dq_aval, dk_aval, dv_aval, dbias_aval

    @staticmethod
750
751
752
753
754
755
756
757
758
759
760
761
    def lowering(
        ctx,
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_cu_seqlen,
        kv_cu_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        q_segment_ids,
        kv_segment_ids,
        q_segment_pos,
        kv_segment_pos,
        *,
        config,
    ):
        """
        Fused attention bwd lowering rules

        Derives the problem sizes from the q/k/v/bias avals in ``ctx.avals_in`` and
        emits either the ``te_fused_attn_backward_ffi`` FFI call (when
        ``is_ffi_enabled()`` is true) or the legacy custom call, whose scalar
        attributes are packed into an opaque descriptor.
        """
        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        qkv_dims = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = qkv_dims

        # Collapse all leading batch dimensions into a single batch size.
        input_batch = reduce(operator.mul, batch_shape)

        if config.attn_bias_type != AttnBiasType.NO_BIAS:
            # The last two bias dims are dropped, the one before them is the head count,
            # and everything in front is treated as batch.
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)
        else:
            bias_batch = bias_heads = 0

        if is_ffi_enabled():
            ffi_attrs = {
                "input_batch": input_batch,
                "bias_batch": bias_batch,
                "q_max_seqlen": q_max_seqlen,
                "kv_max_seqlen": kv_max_seqlen,
                "attn_heads": attn_heads,
                "num_gqa_groups": num_gqa_groups,
                "bias_heads": bias_heads,
                "head_dim": head_dim,
                "max_segments_per_seq": config.max_segments_per_seq,
                "scaling_factor": float(config.scaling_factor),
                "dropout_probability": float(config.dropout_probability),
                "bias_type": int(config.attn_bias_type.value),
                "mask_type": int(config.attn_mask_type.value),
                "qkv_layout": int(config.qkv_layout.value),
                "is_training": config.is_training,
                "deterministic": not FusedAttnHelper.is_non_deterministic_allowed(),
                "window_size_left": config.window_size[0],
                "window_size_right": config.window_size[1],
            }
            # ffi_lowering needs the number of operands to match primitive.lowering,
            # so the segment ids/positions are forwarded as well.
            return ffi.ffi_lowering("te_fused_attn_backward_ffi")(
                ctx,
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                q_seq_offsets,
                k_seq_offsets,
                q_segment_ids,
                kv_segment_ids,
                q_segment_pos,
                kv_segment_pos,
                **ffi_attrs,
            )

        # Legacy custom-call path: only the first twelve operands are passed.
        legacy_operands = [
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
        ]
        legacy_operand_shapes = map(lambda operand: operand.type.shape, legacy_operands)
        result_types = [
            ir.RankedTensorType.get(out_aval.shape, mlir.dtype_to_ir_type(out_aval.dtype))
            for out_aval in ctx.avals_out
        ]
        call_args = CustomCallArgsWrapper(result_types, legacy_operands, legacy_operand_shapes)

        # The last output aval describes the workspace buffer.
        wkspace_aval = ctx.avals_out[-1]

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            input_batch,
            bias_batch,
            q_max_seqlen,
            kv_max_seqlen,
            attn_heads,
            num_gqa_groups,
            bias_heads,
            head_dim,
            config.max_segments_per_seq,
            wkspace_aval.size,
            config.scaling_factor,
            config.dropout_probability,
            config.attn_bias_type,
            config.attn_mask_type,
            config.qkv_layout,
            jax_dtype_to_te_dtype(q_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            config.is_training,
            not FusedAttnHelper.is_non_deterministic_allowed(),
            config.window_size[0],
            config.window_size[1],
        )

        return custom_caller(FusedAttnBwdPrimitive.name, call_args, opaque, has_side_effect=False)

    @staticmethod
880
881
882
883
884
885
886
887
888
889
890
    def impl(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        q_seq_offsets,
        k_seq_offsets,
        _q_segment_ids,
        _kv_segment_ids,
        _q_segment_pos,
        _kv_segment_pos,
        config,
    ):
        """
        Fused attention bwd implementation

        Normalizes the sequence description into per-sequence lengths and offsets,
        compacts them for THD layouts, converts the lengths into cumulative
        lengths and binds the inner primitive. Returns ``(dq, dk, dv, dbias)``;
        the fifth output of the inner primitive is discarded.
        """
        assert FusedAttnBwdPrimitive.inner_primitive is not None

        seq_desc = SequenceDescriptor(
            seqlens=(q_seqlen, kv_seqlen),
            seq_offsets=(q_seq_offsets, k_seq_offsets),
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )
        seqlens, seq_offsets = seq_desc.get_seqlens_and_offsets(
            config.attn_mask_type,
            config.qkv_layout,
            config.window_size,
            config.max_segments_per_seq,
        )
        q_seqlen, kv_seqlen = seqlens
        q_seq_offsets, k_seq_offsets = seq_offsets

        if config.qkv_layout.is_thd():

            def _compact_valid(values, keep, fill_value=-1):
                """Move entries selected by `keep` to the front (flattened order), pad with fill_value."""
                flat = values.flatten()
                count = flat.size
                # Unselected slots get index `count`, which is out of range, so the
                # take below emits fill_value for them.
                kept = jnp.nonzero(keep.flatten(), size=count, fill_value=count)[0]
                # TODO(rewang): try indices_are_sorted
                gathered = jnp.take(flat, kept, fill_value=fill_value)
                return jnp.reshape(gathered, values.shape)

            def _to_flat_offsets(offsets, batch, max_seqlen):
                """Shift the non-negative offsets of row b by b * max_seqlen; keep negatives as-is."""
                return jnp.where(
                    offsets >= 0,
                    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
                    offsets,
                )

            batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
                q, k, v, config.qkv_layout
            )
            assert len(batch) == 1
            kv_batch = q_batch = batch[0]

            # Gather valid seqlens, i.e. the ones greater than 0.
            # cuDNN version < 9.3.0:
            #   [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
            # cuDNN version >= 9.3.0, which supports act_seqlen = 0:
            #   [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
            seqlen_fill = 0 if get_cudnn_version() >= (9, 3, 0) else -1
            q_seqlen = _compact_valid(q_seqlen, q_seqlen > 0, fill_value=seqlen_fill)
            kv_seqlen = _compact_valid(kv_seqlen, kv_seqlen > 0, fill_value=seqlen_fill)

            # Flatten the offset calculation
            # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
            q_seq_offsets = _to_flat_offsets(q_seq_offsets, q_batch, q_max_seqlen)
            k_seq_offsets = _to_flat_offsets(k_seq_offsets, kv_batch, kv_max_seqlen)

            # Gather valid offsets, i.e. the ones greater than or equal to 0
            #   [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
            # and set the unused positions to the max size (batch * max_seqlen)
            #   [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
            q_seq_offsets = _compact_valid(
                q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
            )
            k_seq_offsets = _compact_valid(
                k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
            )

        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

        grads = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
            config=config,
        )
        dq, dk, dv, dbias, _ = grads
        return dq, dk, dv, dbias

    @staticmethod
995
    def batcher(batched_args, batch_dims, *, config):
        """
        Fused attention bwd batching rule

        Binds the outer primitive on the batched operands unchanged. The reported
        output batch dims are those of q, k and v, followed by q's batch dim again
        for the fourth output.
        """
        check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None

        q_axis, k_axis, v_axis, *_ = batch_dims
        outputs = FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config)
        return outputs, (q_axis, k_axis, v_axis, q_axis)
1005
1006

    @staticmethod
1007
1008
    def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
        """
        Infer output shardings for the bwd primitive.

        Each of the four outputs (dq, dk, dv, dbias) takes the padded partition
        spec of the operand at the same position (q, k, v, bias).
        """
        del config, result_infos

        def _mirror(arg_info):
            # Same spec as the operand, re-wrapped on the given mesh.
            return NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_info)))

        return tuple(_mirror(arg_infos[pos]) for pos in range(4))

    @staticmethod
1020
    def partition(config, mesh, arg_infos, result_infos):
        """
        Partition rule for the bwd primitive.

        Returns ``(mesh, sharded_impl, out_shardings, arg_shardings)``. The
        gradients reuse the padded specs of q, k, v and bias. When a bias is
        present, the local dbias is summed via ``all_reduce_sum_along_dp_fsdp``.
        """
        del result_infos

        def _mirror(arg_info):
            # Same spec as the operand, re-wrapped on the given mesh.
            return NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_info)))

        # (dq, dk, dv, dbias) follow (q, k, v, bias).
        out_shardings = tuple(_mirror(arg_infos[pos]) for pos in range(4))

        operand_shardings = [info.sharding for info in arg_infos]
        # The last two operands (segment positions) take the shardings of the two
        # operands right before them (segment ids).
        operand_shardings[-2], operand_shardings[-1] = (
            operand_shardings[-4],
            operand_shardings[-3],
        )
        arg_shardings = tuple(operand_shardings)

        def sharded_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            operands = (
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlen,
                kv_seqlen,
                q_seq_offsets,
                k_seq_offsets,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
            )
            local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
                *operands, config=config
            )
            if config.attn_bias_type is AttnBiasType.NO_BIAS:
                # No bias gradient to combine; pass the local value through.
                return local_dq, local_dk, local_dv, local_dbias
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            return local_dq, local_dk, local_dv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register the fused attention backward primitive via the helper imported from .base.
register_primitive(FusedAttnBwdPrimitive)


Reese Wang's avatar
Reese Wang committed
1084
def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
    """Reorders a tensor for load balancing the compute of causal attention.

    The sequence dimension is cut into ``2 * cp_size`` equal chunks. With
    ``to_contiguous=False`` the chunks are laid out as pairs
    ``(r, 2 * cp_size - 1 - r)`` for ``r = 0 .. cp_size - 1``. With
    ``to_contiguous=True`` that layout is undone and the chunks return to
    their natural order.

    Args:
        tensor: array to reorder.
        cp_size: context parallel size; must be 1 or an even number.
        seq_dim: index of the sequence dimension of `tensor`.
        to_contiguous: direction of the reordering (see above).

    Returns:
        `tensor` itself when ``cp_size == 1``, otherwise the reordered array with
        the original shape.

    Raises:
        ValueError: if `cp_size` is odd (and not 1) or the sequence length is not
            a multiple of ``2 * cp_size``.
    """
    if cp_size == 1:
        return tensor

    if cp_size % 2 != 0:
        raise ValueError(f"{cp_size=} must be a multiple of 2.")

    num_chunks = 2 * cp_size

    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
    if tensor.shape[seq_dim] % num_chunks != 0:
        raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")

    # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, H, D]
    # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D]
    ori_tensor_shape = tensor.shape
    chunked = tensor.reshape(
        (
            *ori_tensor_shape[:seq_dim],
            num_chunks,
            ori_tensor_shape[seq_dim] // num_chunks,
            *ori_tensor_shape[seq_dim + 1 :],
        )
    )

    # Source chunk for every destination slot, two slots per rank.
    if to_contiguous:
        half = cp_size // 2
        # First half of the output: chunks stored at even slots 0, 2, 4, ...
        chunk_order = [src for rank in range(half) for src in (4 * rank, 4 * rank + 2)]
        # Second half of the output: chunks stored at odd slots, walked from the back.
        chunk_order += [
            src
            for rank in range(half)
            for src in (num_chunks - 1 - 4 * rank, num_chunks - 3 - 4 * rank)
        ]
    else:
        chunk_order = [src for rank in range(cp_size) for src in (rank, num_chunks - 1 - rank)]

    # One gather over the chunk axis, then collapse back to the input shape.
    reordered = jnp.take(chunked, jnp.array(chunk_order), axis=seq_dim)
    return reordered.reshape(ori_tensor_shape)


Reese Wang's avatar
Reese Wang committed
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
    """Reorders a tensor for load balancing with striped pattern

    The sequence dimension is viewed as a 2-D grid and transposed. Forward
    (``is_inverse=False``) views it as ``[S / cp_size, cp_size]``; inverse views
    it as ``[cp_size, S / cp_size]``, which undoes the forward reordering.

    Raises:
        ValueError: if the sequence length is not a multiple of `cp_size`.
    """
    origin_shape = tensor.shape
    if origin_shape[seq_dim] % cp_size != 0:
        raise ValueError(
            "Expected origin_shape[seq_dim] is multiple of cp_size but got"
            f" {origin_shape[seq_dim]=} and {cp_size=}"
        )

    stripe_len = origin_shape[seq_dim] // cp_size
    seq_grid = (cp_size, stripe_len) if is_inverse else (stripe_len, cp_size)
    grid_shape = [*origin_shape[:seq_dim], *seq_grid, *origin_shape[seq_dim + 1 :]]

    # Transpose the two axes the sequence dimension was split into.
    transposed = jnp.swapaxes(tensor.reshape(grid_shape), seq_dim, seq_dim + 1)
    return transposed.reshape(origin_shape)


1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
    """Helper class to assist with running the all-gather strategy for CP attention."""

    # Device mesh that contains the context parallel axis.
    mesh: jax.sharding.Mesh
    # Attention configuration of the enclosing (non-stepped) call.
    config: _FusedAttnConfig

    def check_supported(self):
        """Checks if the context parallel implementation is supported by the given arguments.

        Raises:
            ValueError: on an unsupported layout, bias, mask type,
                ``max_segments_per_seq != 1`` or non-zero dropout (checked in that order).
        """
        header = "Context parallel fused attention"

        allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
        layout_ok = self.config.qkv_layout in allowed_layouts
        if not layout_ok:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
            )

        has_bias = self.config.attn_bias_type != AttnBiasType.NO_BIAS
        if has_bias:
            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

        allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
        mask_ok = self.config.attn_mask_type in allowed_masks
        if not mask_ok:
            raise ValueError(
                f"{header} only supports masking types: "
                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
            )

        if self.config.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got:"
                f" {self.config.max_segments_per_seq}"
            )

        if self.config.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

    def get_adjusted_mask(self):
        """Converts the mask for context parallelism.

        ``CAUSAL_MASK`` becomes ``CAUSAL_BOTTOM_RIGHT_MASK``; any other mask type
        is returned unchanged.
        """
        mask = self.config.attn_mask_type
        return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK if mask == AttnMaskType.CAUSAL_MASK else mask

    def get_step_config(self) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention.

        Every listed field is copied from ``self.config`` except the mask type,
        which comes from :meth:`get_adjusted_mask`.
        """
        base = self.config
        return _FusedAttnConfig(
            attn_bias_type=base.attn_bias_type,
            attn_mask_type=self.get_adjusted_mask(),
            qkv_layout=base.qkv_layout,
            scaling_factor=base.scaling_factor,
            dropout_probability=base.dropout_probability,
            is_training=base.is_training,
            max_segments_per_seq=base.max_segments_per_seq,
            window_size=base.window_size,
            context_parallel_load_balanced=base.context_parallel_load_balanced,
            cp_axis=base.cp_axis,
        )

    def all_gather_kv(self, k, v):
        """Performs a all-gather of k and v over context parallel ranks.

        For kv-packed layouts only `k` is gathered and `v` is passed through; for
        separate layouts both are gathered; any other layout returns the inputs
        unchanged.
        """

        def _gather(x):
            gathered = lax_paral_op(
                x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
            )
            if not self.config.context_parallel_load_balanced:
                return gathered
            # Undo the load-balancing chunk order so the gathered sequence is contiguous.
            cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
            return reorder_causal_dual_chunk_swap(gathered, cp_size, 1, to_contiguous=True)

        layout = self.config.qkv_layout
        if layout.is_kvpacked():
            return _gather(k), v
        if layout.is_separate():
            return _gather(k), _gather(v)

        return k, v  # fall through

    def reduce_scatter_dkv(self, dk, dv):
        """Performs a reduce-scatter of dk and dv over context parallel ranks.

        Layout handling mirrors :meth:`all_gather_kv`: kv-packed scatters only
        `dk`, separate scatters both, anything else is returned unchanged.
        """

        def _scatter(x):
            if self.config.context_parallel_load_balanced:
                # Re-apply the load-balancing chunk order before scattering.
                cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
                x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)

            return lax_paral_op(
                x,
                lax.psum_scatter,
                self.config.cp_axis,
                mesh=self.mesh,
                scatter_dimension=1,
                tiled=True,
            )

        layout = self.config.qkv_layout
        if layout.is_kvpacked():
            return _scatter(dk), dv
        if layout.is_separate():
            return _scatter(dk), _scatter(dv)

        return dk, dv  # fall through

    def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
        """Returns sequence lengths of KV to use for each sub rank of the given cp_rank.

        Example: CP=4, MaxLen = 1024, Unbalanced
           cp_rank 0: [128, 256]
           cp_rank 1: [384, 512]
           cp_rank 2: [640, 768]
           cp_rank 3: [896, 1024]

        Example: CP=4, MaxLen = 1024, Balanced
           cp_rank 0: [128, 1024]
           cp_rank 1: [256, 896]
           cp_rank 2: [384, 768]
           cp_rank 3: [512, 640]
        """
        if not self.config.context_parallel_load_balanced:
            # Consecutive sub-chunks 2 * cp_rank and 2 * cp_rank + 1.
            first = (cp_rank * 2 + 1) * kv_seqlen_per_subrank
            second = (cp_rank * 2 + 2) * kv_seqlen_per_subrank
        else:
            # One sub-chunk counted from the front, its partner counted from the back.
            first = (cp_rank + 1) * kv_seqlen_per_subrank
            second = kv_max_seqlen - cp_rank * kv_seqlen_per_subrank
        return [first, second]

    def slice_kv(self, k, v, slice_seq_len):
        """Slices k and v tensors to a sequence length of slice_seq_len.

        The slice starts at index 0 of axis 1. kv-packed layouts slice only `k`;
        separate layouts slice both; other layouts are returned unchanged.
        """

        def _head(x):
            return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)

        layout = self.config.qkv_layout
        if layout.is_kvpacked():
            return _head(k), v
        if layout.is_separate():
            return _head(k), _head(v)

        return k, v  # fall through

    def pad_kv(self, dk, dv, pad_seq_len):
        """Pads dk and dv tensors to a sequence length of pad_seq_len.

        ``pad_seq_len`` zeros are appended at the end of axis 1. kv-packed
        layouts (5 padded dims) pad only `dk`; separate layouts (4 padded dims)
        pad both; other layouts are returned unchanged.
        """

        def _zero_pad(x, widths):
            return jnp.pad(x, widths, "constant", constant_values=0.0)

        layout = self.config.qkv_layout
        if layout.is_kvpacked():
            widths = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
            return _zero_pad(dk, widths), dv
        if layout.is_separate():
            widths = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
            return _zero_pad(dk, widths), _zero_pad(dv, widths)

        return dk, dv  # fall through


class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Attention Forward with Context Parallelism Primitive

    This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """
        Partition rule for the all-gather context parallel forward pass.

        Falls back to ``FusedAttnFwdPrimitive.partition`` when the context parallel
        axis has size 1. Otherwise returns ``(mesh, impl, out_shardings,
        arg_shardings)`` where ``impl`` all-gathers K/V and runs two attention
        calls (one per half of the local Q) for the current CP rank.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Operand 4 is the seed; it uses the same sharding as the rng_state output.
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # cuDNN does not support right-aligned masking with dynamic sequence length padding.
            # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
            # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
            # meeting the expectation of the SPMD model.
            # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
            # mask/sequence length tensor to avoid this unrolled loop.
            def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                # Each rank processes its local Q as two halves along the sequence axis.
                q_split = jnp.split(q, 2, axis=1)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == AttnMaskType.NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    # Floor division keeps the sequence lengths integral, matching the
                    # backward primitive; true division would promote them to floats.
                    q_seqlen_for_step = q_seqlen // (cp_size * 2)
                    num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

                    output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        seed,
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(),
                    )
                    results.append((output, softmax_aux, rng_state))

                # Stitch the two halves back together along their sequence axes.
                output = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2)
                rng_state = results[1][2]  # Use the final RNG state

                return output, softmax_aux, rng_state

            k_ag, v_ag = helper.all_gather_kv(k, v)

            # One branch per CP rank; the runtime rank selects which one executes.
            functions = [
                partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed)
                for idx in range(cp_size)
            ]

            return lax.switch(cp_rank, functions)

        return mesh, impl, out_shardings, arg_shardings


# Register the all-gather context parallel forward primitive via the helper imported from .base.
register_primitive(FusedAttnCPWithAllGatherFwdPrimitive)


class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Attention Backward with Context Parallelism Primitive.

    This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks.
    The gradients are subsequently reduce-scattered back to each context parallel rank.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """
        Partition rule for the all-gather context parallel backward pass.

        Falls back to ``FusedAttnBwdPrimitive.partition`` when the context parallel
        axis has size 1. Otherwise returns ``(mesh, impl, out_shardings,
        arg_shardings)`` where ``impl`` all-gathers K/V, runs two backward calls
        (one per half of the local Q) for the current CP rank, and
        reduce-scatters dK/dV.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        cp_enabled = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not cp_enabled or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not cp_enabled:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Ensure we can support this configuration with context parallelism.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        del result_infos

        def _mirror(arg_info):
            # Same spec as the operand, re-wrapped on the given mesh.
            return NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_info)))

        # (dq, dk, dv, dbias) follow (q, k, v, bias).
        out_shardings = tuple(_mirror(arg_infos[pos]) for pos in range(4))
        arg_shardings = tuple(info.sharding for info in arg_infos)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # K/V are sliced (and dK/dV padded back) for every mask type except NO_MASK.
            kv_is_sliced = config.attn_mask_type != AttnMaskType.NO_MASK

            # See comment in FusedAttnCPWithAllGatherFwdPrimitive.partition for why we
            # define this function.
            def _cross_attn_bwd(
                idx,
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlen,
                kv_seqlen,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
            ):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                # Each rank processes its local Q (and matching outputs) as two halves.
                q_halves = jnp.split(q, 2, axis=1)
                output_halves = jnp.split(output, 2, axis=1)
                doutput_halves = jnp.split(doutput, 2, axis=1)
                softmax_aux_halves = jnp.split(softmax_aux, 2, axis=2)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                step_config = helper.get_step_config()
                q_seqlen_for_step = q_seqlen // (cp_size * 2)

                dq_parts, dk_parts, dv_parts, dbias_parts = [], [], [], []
                for sub_idx in range(2):
                    kv_len = kv_seqlens_for_rank[sub_idx]

                    if kv_is_sliced:
                        k_step, v_step = helper.slice_kv(k, v, kv_len)
                    else:
                        k_step, v_step = k, v  # full kv used for unmasked

                    num_kv_chunks = kv_max_seqlen // kv_len
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

                    dq_step, dk_step, dv_step, dbias_step = FusedAttnBwdPrimitive.impl(
                        q_halves[sub_idx],
                        k_step,
                        v_step,
                        bias,
                        softmax_aux_halves[sub_idx],
                        rng_state,
                        output_halves[sub_idx],
                        doutput_halves[sub_idx],
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=step_config,
                    )

                    # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
                    if kv_is_sliced:
                        dk_step, dv_step = helper.pad_kv(dk_step, dv_step, kv_max_seqlen - kv_len)

                    dq_parts.append(dq_step)
                    dk_parts.append(dk_step)
                    dv_parts.append(dv_step)
                    dbias_parts.append(dbias_step)

                # dq halves are stitched along the sequence axis; dk/dv contributions are summed.
                dq_local = jnp.concatenate((dq_parts[0], dq_parts[1]), axis=1)
                dk_local_pad = dk_parts[0] + dk_parts[1]
                dv_local_pad = dv_parts[0] + dv_parts[1]
                # dbias is taken from the second sub-step.
                return dq_local, dk_local_pad, dv_local_pad, dbias_parts[1]

            k_ag, v_ag = helper.all_gather_kv(k, v)

            shared_operands = (
                q,
                k_ag,
                v_ag,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_seqlen,
                kv_seqlen,
                _q_segment_ids,
                _kv_segment_ids,
                _q_segment_pos,
                _kv_segment_pos,
            )
            # One branch per CP rank; the runtime rank selects which one executes.
            functions = [
                partial(_cross_attn_bwd, idx, *shared_operands) for idx in range(cp_size)
            ]

            dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
            dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)

            return dq, dk, dv, dbias

        return mesh, impl, out_shardings, arg_shardings


# Register the all-gather context parallel backward primitive via the helper imported from .base.
register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)


1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
@dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper:
    """Helper class to assist with running the P2P ring strategy for CP attention.

    Bundles the mesh and the fused-attention config and offers the small utilities the
    ring primitives share: support checks, per-step config construction, KV stacking,
    ring permutation, online-softmax correction and per-step sequence-length clamping.
    """

    # Device mesh on which the context-parallel axis (``config.cp_axis``) lives.
    mesh: jax.sharding.Mesh
    # Fused attention configuration of the enclosing primitive call.
    config: _FusedAttnConfig

    @staticmethod
    def use_scanloop():
        """Returns true if the implementation will use a scan loop for iteration.

        The scan loop is used only when ``NVTE_FUSED_RING_ATTENTION_USE_SCAN`` (default
        ``"1"``) parses to a non-zero integer *and* the ``--xla_ignore_channel_id`` XLA
        flag (default ``True``) is truthy.
        """
        use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))

        # nvbug(4675071): Disable the HLO verifier for channel ID checks.
        # A WAR was added to XLA: https://github.com/openxla/xla/pull/16779
        def truthy(val):
            return val.lower() in ["1", "true"]

        # Short-circuit keeps the original semantics: the flag is only looked up
        # when the environment variable enables the scan loop.
        return use_scan and get_xla_flag("--xla_ignore_channel_id", default=True, cast=truthy)

    def check_supported(self):
        """Checks if the context parallel implementation is supported by the given arguments.

        Raises:
            ValueError: if the layout, bias type, mask type, ``max_segments_per_seq`` or
                dropout probability in ``self.config`` is not handled by ring attention.

        Emits a warning (not an error) when the scan loop is disabled.
        """
        header = "Context parallel fused ring attention"

        if self.config.qkv_layout.is_thd():
            allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
        else:
            allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]

        if self.config.qkv_layout not in allowed_layouts:
            raise ValueError(
                f"{header} only supports layouts:"
                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
            )

        if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

        if self.config.qkv_layout.is_thd():
            allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK]
        else:
            allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
        if self.config.attn_mask_type not in allowed_masks:
            raise ValueError(
                f"{header} only supports masking types: "
                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
            )

        if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1:
            raise ValueError(
                f"{header} only supports max_segments_per_seq == 1 got:"
                f" {self.config.max_segments_per_seq}"
            )

        if self.config.dropout_probability != 0.0:
            raise ValueError(f"{header} does not support dropout")

        # We want to encourage use of scan loop to minimize unrolling and ensure more
        # predictable scheduling from XLA. The unrolled flavor will be supported but
        # not the preferred implementation.
        if not self.use_scanloop():
            # The flag named here must be the one use_scanloop() actually reads.
            warnings.warn(
                "Scan loop is disabled for fused ring attention. To enable set"
                " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment and"
                " add --xla_ignore_channel_id=true to XLA_FLAGS."
            )

    def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention.

        Every field is copied from ``self.config`` except ``attn_mask_type`` (taken from
        the argument) and ``qkv_layout`` (fixed to ``QKVLayout.BSHD_BS2HD``).
        """
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=QKVLayout.BSHD_BS2HD,
            scaling_factor=self.config.scaling_factor,
            dropout_probability=self.config.dropout_probability,
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
            window_size=self.config.window_size,
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
        )

    def stack_kv(self, k, v):
        """Stacks k and v tensors if not stacked.

        Returns ``k`` unchanged for KV-packed layouts, ``jnp.stack([k, v], axis=2)`` for
        separate layouts, and an empty placeholder array for any other layout.
        """
        if self.config.qkv_layout.is_kvpacked():
            return k
        if self.config.qkv_layout.is_separate():
            return jnp.stack([k, v], axis=2)
        # Fall through: placeholder only, built lazily since the branches above do not need it.
        return jnp.zeros(0, dtype=k.dtype)

    def unstack_kv(self, kv):
        """Un-stacks k and v tensors if not stacked.

        Inverse of ``stack_kv``: KV-packed layouts return ``(kv, placeholder)``, separate
        layouts return the result of ``jnp.unstack(kv, axis=2)``, and any other layout
        returns two empty placeholders.
        """
        _not_used = jnp.zeros(0, dtype=kv.dtype)
        if self.config.qkv_layout.is_kvpacked():
            return kv, _not_used
        if self.config.qkv_layout.is_separate():
            return jnp.unstack(kv, axis=2)
        return _not_used, _not_used  # fall through

    def permute_kv(self, kv, cp_perm):
        """Permutes kv around the ring as described by cp_perm."""
        return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)

    @staticmethod
    def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
        """
        Corrects the output and softmax_aux tensor after each iteration of ring attention.

        See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
        derivation of this equation.

        Returns the merged ``(output, softmax_aux)`` pair.
        """
        # The sigmoid weight is computed in the softmax_aux axis order and transposed
        # (axes 1 and 2 swapped) so it broadcasts against the output tensors.
        new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
            0, 2, 1, 3
        ) * (output - partial_output)
        new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
        return new_out, new_aux

    def adjust_seqlen(self, seqlen, max_seqlen, idx):
        """Adjust the sequence length per step.

        Computes ``seqlen - max_seqlen * idx`` and clamps the result to the closed
        range ``[0, max_seqlen]``.
        """
        return jnp.clip(seqlen - max_seqlen * idx, 0, max_seqlen)


class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Ring Attention Forward Primitive

    Overrides ``partition`` so that, when the ``config.cp_axis`` mesh axis has more than
    one device, the KV block is rotated around the ring with ``ppermute`` and one
    ``FusedAttnFwdPrimitive.impl`` call is issued per ring step.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the per-shard forward implementation and its shardings.

        Falls back to ``FusedAttnFwdPrimitive.partition`` when the ``config.cp_axis``
        mesh axis has size 1.

        Returns:
            ``(mesh, ring_attn_fwd_impl, out_shardings, arg_shardings)``.

        Raises:
            AssertionError: if context parallelism is active and
                ``config.window_size[0] != -1``.
            ValueError: propagated from ``helper.check_supported()``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        # Argument 4 is the seed; every other argument keeps its incoming sharding.
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def ring_attn_fwd_impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            """Per-shard ring attention forward; returns (output, softmax_aux, rng_state)."""
            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            batch, q_max_seqlen, head, _ = q.shape
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            # Each rank sends its KV block to the next rank in the ring.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            # Accumulators are kept in float32 and cast back to q.dtype at the end.
            output = jnp.zeros(q.shape, dtype=jnp.float32)
            softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape, dtype=result_infos[2].dtype)

            def scan_kv_block(idx, carry):
                """One ring step: launch the KV send, attend to the current KV, merge."""
                kv, output, softmax_aux = carry

                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)

                def mask_compute(attn_mask_type):
                    # Full Q against the full current KV block with the requested mask.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        seed,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(attn_mask_type),
                    )
                    return output_per_step, softmax_aux_per_step

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # Full Q against the first half (along axis 1) of the current KV block.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
                    kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1)
                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                        q,
                        kv_part,
                        _not_used,
                        bias,
                        seed,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(AttnMaskType.NO_MASK),
                    )
                    return output_per_step, softmax_aux_per_step

                def half_q_no_mask_compute():
                    # Second half (along axis 1) of Q against the full current KV block.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                        q_part,
                        kv,
                        _not_used,
                        bias,
                        seed,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(AttnMaskType.NO_MASK),
                    )
                    # Pad the untouched first half so both lax.cond branches return
                    # full-size tensors: zeros for the output, -inf for the softmax stats.
                    output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
                    softmax_aux_per_step = jnp.concat(
                        [
                            jnp.full_like(softmax_aux_per_step, -jnp.inf),
                            softmax_aux_per_step,
                        ],
                        axis=2,
                    )
                    return output_per_step, softmax_aux_per_step

                def skip_compute():
                    # No attention call for this step: zero output and -inf softmax stats.
                    output_per_step = jnp.zeros_like(q)
                    softmax_aux_per_step = jnp.full(
                        (batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32
                    )
                    return output_per_step, softmax_aux_per_step

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # This is for nested jax.lax.cond
                    def jax_cond_wrap():
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    # Step 0 uses the causal mask; later steps pick an unmasked variant.
                    output_per_step, softmax_aux_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    output_per_step, softmax_aux_per_step = no_mask_compute()

                def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    # pylint: disable=unused-argument
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    return helper.correct_output_and_softmax_aux(
                        output, softmax_aux, output_per_step, softmax_aux_per_step
                    )

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    (idx == 0),
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, output, softmax_aux)

            carry = (kv, output, softmax_aux)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                # Unrolled variant: one traced copy of the step per ring position.
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (kv, output, softmax_aux) = carry

            output = output.astype(q.dtype)
            return output, softmax_aux, rng_state

        return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnFwdPrimitive)


class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Ring Attention Backward Primitive

    Overrides ``partition`` so that, when the ``config.cp_axis`` mesh axis has more than
    one device, the KV block and its accumulated gradient travel around the ring together
    and one ``FusedAttnBwdPrimitive.impl`` call is issued per ring step.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the per-shard backward implementation and its shardings.

        Falls back to ``FusedAttnBwdPrimitive.partition`` when the ``config.cp_axis``
        mesh axis has size 1.

        Returns:
            ``(mesh, ring_attn_bwd_impl, out_shardings, arg_shardings)``.

        Raises:
            AssertionError: if context parallelism is active and
                ``config.window_size[0] != -1``.
            ValueError: propagated from ``helper.check_supported()``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        del result_infos
        # Gradients take the sharding specs of the inputs they correspond to.
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def ring_attn_bwd_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            """Per-shard ring attention backward; returns (dq, dk, dv, dbias)."""
            _not_used = jnp.zeros(0, dtype=output.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            q_max_seqlen = q.shape[1]
            kv_max_seqlen = k.shape[1]

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            # Each rank sends to the next rank in the ring.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
            dbias = jnp.zeros_like(bias)

            def scan_kv_block(idx, carry):
                """One ring step: launch the KV/dKV send, compute grads, accumulate."""

                kv, dq, dk_dv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.

                kv_dk_dv = jnp.stack([kv, dk_dv])
                kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)

                def mask_compute(attn_mask_type):
                    # Full Q against the full current KV block with the requested mask.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        softmax_aux,
                        rng_state,
                        output,
                        doutput,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(attn_mask_type),
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
                no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)

                def half_kv_no_mask_compute():
                    # Full Q against the first half (along axis 1) of the current KV block.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
                    kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
                        q,
                        kv_part,
                        _not_used,
                        bias,
                        softmax_aux,
                        rng_state,
                        output,
                        doutput,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(AttnMaskType.NO_MASK),
                    )
                    # Zero-pad the unused second half so the KV gradient is full-size.
                    dk_dv_per_step = jnp.concat(
                        [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
                    )
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def half_q_no_mask_compute():
                    # Second half (along the sequence axis) of Q and of every Q-aligned
                    # tensor against the full current KV block.
                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)

                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
                    doutput_part = lax.slice_in_dim(
                        doutput, q_max_seqlen // 2, q_max_seqlen, axis=1
                    )
                    output_part = lax.slice_in_dim(output, q_max_seqlen // 2, q_max_seqlen, axis=1)

                    # softmax_aux is sliced along axis 2 rather than axis 1.
                    softmax_aux_part = lax.slice_in_dim(
                        softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
                    )

                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
                        q_part,
                        kv,
                        _not_used,
                        bias,
                        softmax_aux_part,
                        rng_state,
                        output_part,
                        doutput_part,
                        q_seqlen_per_step,
                        kv_seqlen_per_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        _q_segment_ids,
                        _kv_segment_ids,
                        _q_segment_pos,
                        _kv_segment_pos,
                        config=helper.get_step_config(AttnMaskType.NO_MASK),
                    )
                    # Zero-pad the unused first half so the Q gradient is full-size.
                    dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
                    return dq_per_step, dk_dv_per_step, dbias_per_step

                def skip_compute():
                    # No attention call for this step: all gradients are zero.
                    return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)

                if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
                    # This is for nested jax.lax.cond
                    def jax_cond_wrap():
                        if config.context_parallel_load_balanced:
                            return lax.cond(
                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
                            )
                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)

                    # Step 0 uses the causal mask; later steps pick an unmasked variant.
                    dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond(
                        idx == 0, causal_mask_compute, jax_cond_wrap
                    )
                else:
                    dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute()

                # Pick up the permuted KV block and KV gradient, then add this step's
                # contribution on top of the received gradient.
                kv_next, dk_dv = jnp.unstack(kv_dk_dv)
                dq = dq + dq_per_step
                dk_dv = dk_dv + dk_dv_per_step
                if config.attn_bias_type != AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step

                return (kv_next, dq, dk_dv, dbias)

            carry = (kv, dq, dk_dv, dbias)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                # Unrolled variant: one traced copy of the step per ring position.
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (kv, dq, dk_dv, dbias) = carry

            # Final permute to put gradients back to their final resting place.
            dk_dv = helper.permute_kv(dk_dv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type != AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dk_dv)
            return dq, dk, dv, global_dbias

        return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnBwdPrimitive)


Reese Wang's avatar
Reese Wang committed
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Striped Ring Attention Forward Primitive

    Overrides ``partition`` so that, when the ``config.cp_axis`` mesh axis has more than
    one device, the KV block and its segment ids/positions are rotated around the ring
    and one ``FusedAttnFwdPrimitive.impl`` call is issued per ring step.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the per-shard striped forward implementation and its shardings.

        Falls back to ``FusedAttnFwdPrimitive.partition`` when the ``config.cp_axis``
        mesh axis has size 1.

        Returns:
            ``(mesh, fwd_impl, out_shardings, arg_shardings)``.

        Raises:
            AssertionError: if context parallelism is active and
                ``config.window_size[0] != -1``.
            ValueError: propagated from ``helper.check_supported()``.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        # Argument 4 is the seed; every other argument keeps its incoming sharding.
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def fwd_impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            """Per-shard striped ring forward; returns (output, softmax_aux, rng_state).

            Raises:
                ValueError: if ``q_segment_ids`` or ``kv_segment_ids`` is empty.
            """
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing segment_ids/pos")

            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            # Each step is handed the stacked KV, so switch the layout to its KV-packed form.
            if not config.qkv_layout.is_qkvpacked():
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # Each rank sends to the next rank in the ring.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            batch, q_max_seqlen, head, _ = q.shape
            # Accumulators are kept in float32 and cast back to q.dtype at the end.
            output = jnp.zeros(q.shape, dtype=jnp.float32)
            softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape, dtype=result_infos[2].dtype)

            def scan_kv_block(idx, carry):
                """One ring step: launch the KV/segment sends, attend, merge."""
                kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry

                # TODO(rewang): To check whether we need special handle for the last idx
                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                    q,
                    kv,
                    _not_used,
                    bias,
                    seed,
                    q_seqlen,
                    kv_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    q_segment_ids,
                    kv_segment_ids,
                    q_segment_pos,
                    kv_segment_pos,
                    config=subblock_config,
                )

                # TODO(rewang): THD softmax_aux layout is actually [B, S, H]
                softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))

                def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # Same merge as the helper's correction, but without the transpose:
                    # softmax_aux was reshaped above to broadcast against the output directly.
                    new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * (
                        output - output_per_step
                    )
                    new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step)
                    return new_out, new_aux

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    idx == 0,
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )

                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux)

            carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                # Unrolled variant: one traced copy of the step per ring position.
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (_, _, _, output, softmax_aux) = carry

            # Reshape (not transpose) the stats to the (batch, head, q_max_seqlen, 1) shape.
            softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1))

            return output.astype(q.dtype), softmax_aux, rng_state

        return mesh, fwd_impl, out_shardings, arg_shardings


register_primitive(FusedRingAttnStripedFwdPrimitive)


class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Backward primitive for fused striped ring attention.

    Only `partition` is overridden here; everything else is inherited from
    `FusedAttnBwdPrimitive`.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """
        Build the partitioned backward implementation for striped ring attention.

        When the context-parallel axis of `mesh` has size 1 (or less), this defers to
        `FusedAttnBwdPrimitive.partition`. Otherwise it returns
        `(mesh, bwd_impl, out_shardings, arg_shardings)`, where `bwd_impl` accumulates the
        gradients over `cp_size` steps while rotating the stacked KV/dKV block with
        `helper.permute_kv`.

        Raises:
            AssertionError: if context parallelism is enabled and `config.window_size[0]`
                is not -1 (sliding window attention).
        """
        cp_enabled = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not cp_enabled or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not cp_enabled:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        arg_shardings = tuple(info.sharding for info in arg_infos)
        # The gradients (dq, dk, dv, dbias) reuse the shardings of (q, k, v, bias).
        out_shardings = tuple(info.sharding for info in arg_infos[:4])

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def bwd_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            """Compute (dq, dk, dv, dbias) by looping over `cp_size` KV blocks."""

            # Segment ids are mandatory on this path.
            if not (q_segment_ids.size and kv_segment_ids.size):
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            # Zero-size placeholder passed in the slot of the (unused) separate V tensor.
            _not_used = jnp.zeros(0, dtype=output.dtype)

            # K and V are stacked into one tensor for better permute scheduling and
            # performance. Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            # The per-step kernel then has to run with a KV-packed layout, unless the layout
            # is already QKV-packed.
            subblock_config = (
                config
                if config.qkv_layout.is_qkvpacked()
                else replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            )

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # Each index `src` is paired with `(src + 1) % cp_size`.
            cp_perm = [(src, (src + 1) % cp_size) for src in range(cp_size)]

            dq = jnp.zeros_like(q)
            dkv = jnp.zeros_like(kv)
            dbias = jnp.zeros_like(bias)

            def scan_kv_block(_idx, carry):
                kv_cur, seg_ids_cur, seg_pos_cur, dq_acc, dkv_acc, dbias_acc = carry

                # Kick off the communication that feeds the next iteration first.
                # KV and the running dKV are stacked together to improve overlap.
                kv_dkv = helper.permute_kv(jnp.stack([kv_cur, dkv_acc]), cp_perm)
                seg_ids_next = helper.permute_kv(seg_ids_cur, cp_perm)
                seg_pos_next = helper.permute_kv(seg_pos_cur, cp_perm)

                # Gradients for the KV block currently held locally.
                dq_step, dkv_step, _, dbias_step = FusedAttnBwdPrimitive.impl(
                    q,
                    kv_cur,
                    _not_used,
                    bias,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_seqlen,
                    kv_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    q_segment_ids,
                    seg_ids_cur,
                    q_segment_pos,
                    seg_pos_cur,
                    config=subblock_config,
                )

                # The permuted dKV receives this step's contribution, so the gradient
                # travels together with its KV block.
                kv_next, dkv_acc = jnp.unstack(kv_dkv)
                dq_acc = dq_acc + dq_step
                dkv_acc = dkv_acc + dkv_step
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias_acc = dbias_acc + dbias_step

                return (kv_next, seg_ids_next, seg_pos_next, dq_acc, dkv_acc, dbias_acc)

            state = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias)
            if helper.use_scanloop():
                state = lax.fori_loop(0, cp_size, scan_kv_block, state)
            else:
                # Unrolled variant of the same loop.
                for step in range(cp_size):
                    state = scan_kv_block(step, state)
            dq, dkv, dbias = state[3:]

            # Final permute to put gradients back to their final resting place.
            dkv = helper.permute_kv(dkv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dkv)
            return dq, dk, dv, global_dbias

        return mesh, bwd_impl, out_shardings, arg_shardings


# Register the striped ring-attention backward primitive via `register_primitive` (imported
# from `.base`).
register_primitive(FusedRingAttnStripedBwdPrimitive)


2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
def _maybe_context_parallel_axis(cp_axis: str):
    if not cp_axis:
        gmr = global_mesh_resource()
        if gmr is not None:
            cp_axis = gmr.cp_resource
        else:
            cp_axis = ""
    return cp_axis


2425
2426
2427
def fused_attn_fwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
2428
    sequence_descriptor: SequenceDescriptor,
2429
    seed: Optional[jnp.ndarray],
Reese Wang's avatar
Reese Wang committed
2430
2431
2432
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
2433
2434
2435
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
2436
    max_segments_per_seq: int,
2437
    window_size: Optional[Tuple[int, int]] = None,
2438
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
2439
2440
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
2441
) -> jnp.ndarray:
2442
    """
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
    Perform the forward pass of with cuDNN fused attention implementations.

    This function implements the following formula:
        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
        It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
Reese Wang's avatar
Reese Wang committed
2463
2464
2465
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
2466
2467
2468
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
2469
2470
2471
2472
2473
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
2474
2475
2476
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
2477
2478
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
2479
    """
2480
2481
2482
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
2483

Reese Wang's avatar
Reese Wang committed
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
2501
        assert bias is None
2502
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
2503

2504
    fused_config = _FusedAttnConfig(
2505
2506
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
2507
        qkv_layout=qkv_layout,
2508
2509
2510
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
2511
        max_segments_per_seq=max_segments_per_seq,
2512
        window_size=(-1, -1) if window_size is None else window_size,
2513
2514
2515
2516
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
    )

2517
    primitive = None
2518
2519
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
2520
            primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
2521
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
2522
2523
2524
2525
2526
            # We must use stripe attention for THD-RING
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnFwdPrimitive.outer_primitive
2527

2528
2529
    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
    return primitive.bind(
2530
2531
2532
        *qkv_for_primitive,
        bias,
        seed,
2533
        *seq_desc_flatten,
2534
        config=fused_config,
2535
2536
2537
    )


2538
2539
2540
def fused_attn_bwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
2541
2542
2543
2544
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
    doutput: jnp.ndarray,
2545
    sequence_descriptor: SequenceDescriptor,
Reese Wang's avatar
Reese Wang committed
2546
2547
2548
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
2549
2550
2551
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
2552
    max_segments_per_seq: int,
2553
    window_size: Optional[Tuple[int, int]] = None,
2554
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
2555
2556
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
2557
):
2558
    """
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
    Perform the backward pass of the cuDNN fused attention implementations.

    Args:
        qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
        used in the forward pass. It supports three formats:
            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
              and value have the same shape (e.g., self-attention).
            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
        rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
        output (jnp.ndarray): The output tensor from the forward pass.
        doutput (jnp.ndarray): The gradient with respect to the output.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        kv_seq_offsets (jnp.ndarray):
            The offsets in the sequence dim for the query, with shape [batch + 1,].
Reese Wang's avatar
Reese Wang committed
2580
2581
2582
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
2583
2584
2585
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
2586
2587
2588
2589
2590
        max_segments_per_seq (int):
            Indicating the maximum number of segments inside a sequence. This parameter is to
            constrain the limit usage and need to be static during the e2e training. The XLA compile
            time and memory consumption is proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size .
2591
2592
2593
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
2594
2595
2596
2597
2598
    Returns:
        Tuple[jnp.ndarray, ...], jnp.ndarray:
        - The first tuple contains the gradients with respect to the input `qkv` tensors in the
          same format as the input `qkv`.
        - The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
2599
    """
2600
2601
2602
    # For optional tensors, which custom calls doesn't support None
    _not_used = jnp.zeros(0, dtype=qkv[0].dtype)

Reese Wang's avatar
Reese Wang committed
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
    if qkv_layout.is_qkvpacked():
        assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used, _not_used]
    elif qkv_layout.is_kvpacked():
        assert (
            len(qkv) == 2
        ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = [*qkv, _not_used]
    elif qkv_layout.is_separate():
        assert (
            len(qkv) == 3
        ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
        qkv_for_primitive = qkv
    else:
        raise ValueError(f"Unknown {qkv_layout=}")

    if attn_bias_type == AttnBiasType.NO_BIAS:
2620
        assert bias is None
2621
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
2622

2623
2624
2625
2626
2627
2628
2629
2630
    fused_config = _FusedAttnConfig(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
2631
        window_size=(-1, -1) if window_size is None else window_size,
2632
2633
2634
2635
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
    )

2636
    primitive = None
2637
2638
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
2639
            primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
2640
        case CPStrategy.RING:
Reese Wang's avatar
Reese Wang committed
2641
2642
2643
2644
            if qkv_layout.is_thd():
                primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
            else:
                primitive = FusedRingAttnBwdPrimitive.outer_primitive
2645
2646
2647

    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
    *qkv_grads, bias_grad = primitive.bind(
2648
        *qkv_for_primitive,
2649
2650
2651
2652
2653
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
2654
        *seq_desc_flatten,
2655
        config=fused_config,
2656
    )
2657
    return tuple(qkv_grads[: len(qkv)]), bias_grad