# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from distributed_test_base import (
    generate_configs,
    generate_context_parallel_configs,
    generate_collectives_count,
    compare_ops,
)
from utils import (
    make_causal_mask,
    make_self_mask,
    assert_allclose,
    print_debug_tensor_stats,
)
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
    fused_attn,
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
    CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource

# We will use the golden reference model from our non distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask

# Half-precision element types swept by the self- and cross-attention tests below.
DTYPES = [
    jnp.float16,
    jnp.bfloat16,
]


class TestDistributedSelfAttn:
    """Distributed fused self-attention on a packed BS3HD QKV tensor.

    Inputs are sharded over the data-parallel (batch) and tensor-parallel (heads) mesh
    axes, and ``fused_attn`` is compared against flax's ``dot_product_attention`` through
    ``compare_ops`` (forward value, gradients and collective counts).
    """

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        """Return the expected collective-communication byte counts.

        Args:
            mesh_shape: device-mesh dimensions, parallel to ``mesh_axes``.
            mesh_axes: mesh axis names.
            mesh_resource: ``MeshResource`` naming the dp / tp axes (either may be None).
            with_bias: whether a bias tensor participates in the computation.
            shape: packed QKV shape ``(batch, seqlen, 3, heads, hidden)``.
            dtype: element dtype of the bias.

        Returns:
            ``generate_collectives_count`` with only all-reduce traffic: 4 bytes for the
            FP32 loss plus, when DP is enabled, the per-TP-shard bias gradient.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, _, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None
        tp_size = 1
        if mesh_resource.tp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        # The bias is sharded over heads by TP, so its gradient is counted per TP shard.
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        # for loss and dbias
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype):
        """Build deterministic inputs and the partition specs used to shard them.

        Returns:
            ``((qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec))``. ``bias`` and
            ``bias_pspec`` are None without a bias; ``mask`` is None unless the mask type is
            PADDING_MASK or CAUSAL_MASK, and ``mask_pspec`` is None only for NO_MASK.
        """
        batch, seqlen, _, heads, _ = shape

        qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)

        bias = (
            random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
            if with_bias
            else None
        )

        # NOTE(review): PADDING_MASK is paired with make_causal_mask and CAUSAL_MASK with
        # make_self_mask, which looks swapped by name -- confirm against the helpers in utils.
        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            mask = make_causal_mask(batch, seqlen)
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        # Batch is sharded over DP and heads over TP; the bias has a singleton batch dim,
        # so it is only sharded over heads, and the mask only over batch.
        qkv_pspec = PartitionSpec(
            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
        )
        bias_pspec = (
            PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
        )
        mask_pspec = (
            PartitionSpec(mesh_resource.dp_resource, None, None, None)
            if attn_mask_type != AttnMaskType.NO_MASK
            else None
        )

        return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
    @pytest.mark.parametrize(
        "attn_bias_type",
        [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
    )
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        attn_mask_type,
        dtype,
    ):
        """Compare sharded fused self-attention (loss and grads) with the flax reference."""
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, _, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        def target_func(qkv, bias, mask):
            return jnp.mean(
                fused_attn(
                    (qkv,),
                    bias,
                    mask,
                    None,
                    attn_bias_type=attn_bias_type,
                    attn_mask_type=attn_mask_type,
                    qkv_layout=QKVLayout.BS3HD,
                    scaling_factor=scaling_factor,
                    dropout_probability=dropout_prob,
                    is_training=is_training,
                )
            )

        def ref_func(qkv, bias, mask):
            # Split the packed QKV axis and drop only that axis, so a size-1 batch or head
            # dimension is never squeezed away by accident.
            query, key, value = (
                jnp.squeeze(part, axis=-3) for part in jnp.split(qkv, [1, 2], axis=-3)
            )

            output = dot_product_attention(
                query,
                key,
                value,
                bias=bias,
                mask=mask,
                # Training mode is non-deterministic; dropout is still off since the rate is 0.
                deterministic=not is_training,
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )

            return jnp.mean(output).astype(dtype)

        with_bias = attn_bias_type != AttnBiasType.NO_BIAS
        (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, with_bias, attn_mask_type, dtype
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
            bias_ = (
                jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
            )
            mask_ = (
                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
            )

            # Differentiate w.r.t. qkv, and additionally w.r.t. the bias when one is present.
            grad_args = (0, 1) if with_bias else (0,)
            out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)

            compare_ops(
                target_func,
                ref_func,
                [qkv_, bias_, mask_],
                collective_count_ref,
                grad_args=grad_args,
                metric_fwd_dtype=dtype,
                metric_bwd_dtype=dtype,
                in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
                out_shardings=(None, out_grad_shardings),
            )


class TestDistributedCrossAttn:
    """Distributed fused cross-attention with a separate Q and packed KV (BSHD_BS2HD).

    Inputs are sharded over the data-parallel (batch) and tensor-parallel (heads) mesh
    axes, and ``fused_attn`` is compared against flax's ``dot_product_attention`` through
    ``compare_ops``.
    """

    def generate_collectives_count_ref(self):
        """Return the expected collectives: a single FP32 all-reduce of the scalar loss."""
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
        """Build deterministic q / packed-kv / mask inputs and their partition specs.

        Args:
            shape: query shape ``(batch, seqlen, heads, hidden)``; kv is generated as
                ``(batch, seqlen, 2, heads, hidden)``.
            mesh_resource: ``MeshResource`` naming the dp / tp axes.
            attn_mask_type: selects which mask (if any) is generated.
            dtype: element dtype of q and kv.

        Returns:
            ``((q, kv, mask), (q_pspec, kv_pspec, mask_pspec))``.
        """
        batch, seqlen, heads, hidden = shape

        q = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype)

        # NOTE(review): PADDING_MASK is paired with make_causal_mask and CAUSAL_MASK with
        # make_self_mask, which looks swapped by name -- confirm against the helpers in utils.
        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            mask = make_causal_mask(batch, seqlen)
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        # Batch is sharded over DP and heads over TP; the mask only over batch.
        q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
        kv_pspec = PartitionSpec(
            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
        )
        mask_pspec = (
            PartitionSpec(mesh_resource.dp_resource, None, None, None)
            if attn_mask_type != AttnMaskType.NO_MASK
            else None
        )

        return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        """Compare sharded fused cross-attention (loss and q/kv grads) with the flax reference."""
        attn_bias_type = AttnBiasType.NO_BIAS
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        def target_func(q, kv, mask):
            return jnp.mean(
                fused_attn(
                    (q, kv),
                    None,
                    mask,
                    None,
                    attn_bias_type=attn_bias_type,
                    attn_mask_type=attn_mask_type,
                    qkv_layout=QKVLayout.BSHD_BS2HD,
                    scaling_factor=scaling_factor,
                    dropout_probability=dropout_prob,
                    is_training=is_training,
                ),
                dtype=jnp.float32,
            )

        def ref_func(query, kv, mask):
            # Only the packed KV axis is removed; squeezing without an axis would also drop a
            # size-1 batch or head dimension.
            key, value = (jnp.squeeze(part, axis=-3) for part in jnp.split(kv, [1], axis=-3))

            output = dot_product_attention(
                query,
                key,
                value,
                bias=None,
                mask=mask,
                # Training mode is non-deterministic; dropout is still off since the rate is 0.
                deterministic=not is_training,
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )

            return jnp.mean(output, dtype=jnp.float32)

        (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, attn_mask_type, dtype
        )
        collective_count_ref = self.generate_collectives_count_ref()
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
            kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
            mask_ = (
                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
            )

            compare_ops(
                target_func,
                ref_func,
                [q_, kv_, mask_],
                collective_count_ref,
                grad_args=(0, 1),
                metric_fwd_dtype=dtype,
                metric_bwd_dtype=dtype,
                in_shardings=(q_pspec, kv_pspec, mask_pspec),
                out_shardings=(None, (q_pspec, kv_pspec)),
            )


336
337
338
339
340
341
@pytest.mark.parametrize(
    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
    "data_shape",
    [
342
343
344
        # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
        pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
        pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
    ],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
    ],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "load_balanced",
    [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
367
class TestDistributedContextParallelSelfAttn:
368
369
370

    def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
        batch, seqlen, heads, hidden = shape
371
        kv_shape = (batch, seqlen, heads // kv_groups, hidden)
372
373
374
375
376
        qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
        q = random.normal(qkey, shape, dtype=dtype)
        k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
        v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)

377
378
379
380
381
382
383
384
385
386
387
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        from test_fused_attn import make_mask

        q_idx, _ = gen_valid(batch, seqlen, 0.0)
        kv_idx, _ = gen_valid(batch, seqlen, 0.0)
        mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type)
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403

        return q, k, v, mask

    def qkv_to_layout(self, q, k, v, qkv_layout):
        qkv_args = ()
        match qkv_layout:
            case QKVLayout.BSHD_BS2HD:
                k, v = map(partial(jnp.expand_dims, axis=-3), [k, v])
                kv = jnp.concatenate((k, v), axis=-3)
                qkv_args = (q, kv)
            case QKVLayout.BSHD_BSHD_BSHD:
                qkv_args = (q, k, v)
            case _:
                raise ValueError(f"Unsupported {qkv_layout=}")
        return qkv_args

404
    def impl_test_context_parallel_attn(
405
406
407
408
409
410
411
412
413
414
415
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
416
        cp_strategy,
417
418
419
420
421
    ):
        attn_bias_type = AttnBiasType.NO_BIAS
        dropout_prob = 0.0
        is_training = True
        dp_size, cp_size, tp_size = mesh_shape
422
        qkv_format = qkv_layout.get_qkv_format()
423

424
        batch, seqlen, num_head, hidden = data_shape
425
426
427
428
429
430

        # Scale the sequence length by 2*CP so its never too small as we scale up test.
        # 2*CP is used since we split into two CP groups for load balancing.
        seqlen = seqlen * cp_size * 2
        data_shape = batch, seqlen, num_head, hidden

431
        num_kv_heads = num_head // kv_groups
432
        scaling_factor = 1.0 / np.sqrt(num_head)
433

434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
        def check_has_backend_for_mask(mask_type):
            return is_fused_attn_kernel_available(
                dtype,
                dtype,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                dropout_prob,
                num_head,
                num_kv_heads,
                seqlen,
                seqlen,
                hidden,
                None,
            )  # no SWA for CP

        # For causal masking we depend on having bottom right support also.
        # The API does not check this and instead we rely on lower level checks to raise
        # and exception if the step backend is not supported. This was a deliberate API
        # decision to keep the CP size or flag out of the function.
        has_backend = check_has_backend_for_mask(attn_mask_type)
        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

        if not has_backend:
            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")
460

461
462
463
        if dp_size > 1 and batch % dp_size != 0:
            pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

464
        # make sure the mesh even divides cp and tp axis
465
466
467
468
        if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
            pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

        def target_func(q, k, v, mask):
469
470
471
472
473
474
475
476
477
478
479
            return fused_attn(
                self.qkv_to_layout(q, k, v, qkv_layout),
                None,  # bias
                mask,
                None,  # seed
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_prob,
                is_training=is_training,
480
                context_parallel_strategy=cp_strategy,
481
482
                context_parallel_causal_load_balanced=load_balanced,
                context_parallel_axis="cp",
483
484
            ).astype(dtype)

485
486
        def ref_func(q, k, v, mask):
            output = general_dot_product_attention(
487
488
489
490
491
                q,
                k,
                v,
                bias=None,
                mask=mask,
492
493
                deterministic=not is_training,
                scale_factor=scaling_factor,
494
495
496
497
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )
498
499
500
501
502
503
            return output.astype(dtype)

        def grad_func(func, *args, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            _, max_seq_len, num_heads, _ = data_shape
            gradient_multiplier = max_seq_len * num_heads
504
            if attn_mask_type.is_causal():
505
506
507
                gradient_multiplier /= 10
            ret_valid = func(*args, **kwargs)
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
508
509
510

        q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)

511
512
        diff_argnums = (0, 1, 2)

513
        # Single GPU (reference)
514
515
516
517
518
519
        ref_func_jit = jax.jit(
            jax.value_and_grad(
                lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums
            )
        )
        ref_fwd, ref_grads = ref_func_jit(q, k, v, mask)
520
521
522
523

        # Multi GPU (function under test)
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
524
        with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False):
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
            qkv_ps = PartitionSpec(
                mesh_resource.dp_resource,
                mesh_resource.cp_resource,
                mesh_resource.tp_resource,
                None,
            )
            qkv_sharding = NamedSharding(mesh, qkv_ps)

            mask_ps = PartitionSpec(
                mesh_resource.dp_resource, None, mesh_resource.cp_resource, None
            )
            mask_sharding = NamedSharding(mesh, mask_ps)

            reorder = partial(
                reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
            )
            inverse_reorder = partial(
                inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
            )

            if load_balanced:
                q, k, v = jax.tree.map(reorder, (q, k, v))

            q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v])
            mask_ = jax.device_put(mask, device=mask_sharding)

            target_func_jit = jax.jit(
552
553
554
555
                jax.value_and_grad(
                    lambda q, k, v, mask: grad_func(target_func, q, k, v, mask),
                    argnums=diff_argnums,
                ),
556
557
558
559
560
561
562
563
564
565
566
567
                in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
                out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
            )

            target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_)

            if load_balanced:
                target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
                target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])

            has_diffs = False

568
569
570
571
            print_debug_tensor_stats("target", target_fwd)
            print_debug_tensor_stats("ref", ref_fwd)
            print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd))
            assert_allclose(target_fwd, ref_fwd, dtype=dtype)
572
573
574
575
576
577

            for i in range(len(target_grads)):
                if ref_grads[i] is None or target_grads[i] is None:
                    # expect both none if one is
                    assert target_grads[i] is None and ref_grads[i] is None
                else:
578
579
580
581
582
583
584
                    print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i])
                    print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i])
                    print_debug_tensor_stats(
                        f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i])
                    )

                assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
585

586
    def test_context_parallel_allgather_attn(
587
588
589
590
591
592
593
594
595
596
597
598
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
599
        return self.impl_test_context_parallel_attn(
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
        )

    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
626
        return self.impl_test_context_parallel_attn(
627
628
629
630
631
632
633
634
635
636
637
638
639
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
        )

640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665

class TestReorderCausalLoadBalancing:
    """Round-trip check: inverse_reorder(reorder(x)) must reproduce x exactly."""

    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest.mark.parametrize(
        "shape",
        [
            pytest.param([1, 16, 1, 1], id="1-16-1-1"),
            pytest.param([4, 32, 12, 32], id="4-32-12-32"),
            pytest.param([3, 32, 8, 64], id="3-32-8-64"),
        ],
    )
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    def test(self, cp_size, shape, qkv_format):
        original = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        # Shapes are listed batch-first; move the sequence axis to the front for SBHD.
        if qkv_format == QKVFormat.SBHD:
            original = original.swapaxes(0, 1)
        expected = original.copy()

        # cp_size and the tensor format are passed as static (non-traced) jit arguments.
        jit_reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
        jit_inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])

        shuffled = jit_reorder(original, cp_size, qkv_format)
        restored = jit_inverse(shuffled, cp_size, qkv_format)

        assert jnp.array_equal(restored, expected)