# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
6
from functools import partial
7
8
9
10
11
12

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
13
from jax.sharding import Mesh, NamedSharding, PartitionSpec
14
15
16
17
18
19
from distributed_test_base import (
    generate_configs,
    generate_context_parallel_configs,
    generate_collectives_count,
    compare_ops,
)
20
21
22
23
24
25
26
from utils import (
    make_causal_mask,
    make_self_mask,
    assert_tree_like_allclose,
    assert_allclose,
    print_debug_tensor_stats,
)
27
from transformer_engine.jax import fp8_autocast
28
29
from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
30
    fused_attn,
31
32
    AttnBiasType,
    AttnMaskType,
33
    QKVLayout,
34
35
36
37
    QKVFormat,
    get_qkv_format,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
38
    CPStrategy,
39
)
40
from transformer_engine.jax.sharding import MeshResource
41

42
43
# Use the golden reference model from the non-distributed attention test module.
from test_fused_attn import general_dot_product_attention, make_mask
44
45
46
47
48
49

# Half-precision input dtypes exercised by the self- and cross-attention tests below.
DTYPES = [
    jnp.float16,
    jnp.bfloat16,
]


class TestDistributedSelfAttn:
    """Fused self-attention on a packed QKV tensor, sharded over DP/TP mesh axes.

    Each case is checked against flax's ``dot_product_attention``. The expected
    collective traffic is also checked via ``compare_ops``.
    """

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        """Return the expected collective byte counts for one self-attention step.

        One FP32 all-reduce is expected for the scalar loss. When a bias is used and
        data parallelism is enabled, the bias gradient (sharded over TP along heads)
        is all-reduced as well.

        Args:
            mesh_shape: Sizes of the device-mesh axes.
            mesh_axes: Names of the device-mesh axes, parallel to ``mesh_shape``.
            mesh_resource: Mesh resource naming the DP/TP axes (either may be ``None``).
            with_bias: Whether the attention call carries a bias tensor.
            shape: Packed QKV shape, unpacked as ``(batch, seqlen, 3, heads, head_dim)``.
            dtype: Input dtype; its item size sizes the bias gradient.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, _, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None
        tp_size = 1
        if mesh_resource.tp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        # for loss and dbias
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype):
        """Build the packed QKV input, optional bias and mask, and their partition specs.

        Returns:
            ``((qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec))``. ``bias`` and
            ``bias_pspec`` are ``None`` without a bias. ``mask`` and ``mask_pspec`` are
            ``None`` when no mask tensor is built for ``attn_mask_type``.
        """
        batch, seqlen, _, heads, _ = shape

        qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)

        bias = (
            random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
            if with_bias
            else None
        )

        # NOTE(review): PADDING_MASK is paired with make_causal_mask and CAUSAL_MASK with
        # make_self_mask. This pairing looks inverted; confirm against utils before changing.
        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            mask = make_causal_mask(batch, seqlen)
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        # Batch is sharded over the DP axis and heads over the TP axis.
        qkv_pspec = PartitionSpec(
            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
        )
        bias_pspec = (
            PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
        )
        # Derive the spec from whether a mask tensor exists, so the spec never
        # disagrees with the (possibly None) mask that is passed in.
        mask_pspec = (
            PartitionSpec(mesh_resource.dp_resource, None, None, None) if mask is not None else None
        )

        return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
    @pytest.mark.parametrize(
        "attn_bias_type",
        [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
    )
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        attn_mask_type,
        dtype,
    ):
        """Compare sharded fused self-attention (BS3HD layout) with the flax reference."""
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, _, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        def target_func(qkv, bias, mask):
            return jnp.mean(
                fused_attn(
                    (qkv,),
                    bias,
                    mask,
                    None,
                    attn_bias_type=attn_bias_type,
                    attn_mask_type=attn_mask_type,
                    qkv_layout=QKVLayout.BS3HD,
                    scaling_factor=scaling_factor,
                    dropout_probability=dropout_prob,
                    is_training=is_training,
                )
            )

        def ref_func(qkv, bias, mask):
            # Unpack along the packed axis (-3) and drop only that axis. A bare
            # squeeze would also remove size-1 batch or head dimensions.
            query, key, value = jnp.split(qkv, [1, 2], axis=-3)
            query = jnp.squeeze(query, axis=-3)
            key = jnp.squeeze(key, axis=-3)
            value = jnp.squeeze(value, axis=-3)

            output = dot_product_attention(
                query,
                key,
                value,
                bias=bias,
                mask=mask,
                # Dropout is active only when not deterministic, i.e. in training.
                deterministic=not is_training,
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )

            return jnp.mean(output).astype(dtype)

        with_bias = attn_bias_type != AttnBiasType.NO_BIAS
        (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, with_bias, attn_mask_type, dtype
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
            bias_ = (
                jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
            )
            mask_ = (
                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
            )

            # Differentiate w.r.t. qkv, plus the bias when there is one.
            grad_args = (0, 1) if with_bias else (0,)
            out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)

            compare_ops(
                target_func,
                ref_func,
                [qkv_, bias_, mask_],
                collective_count_ref,
                grad_args=grad_args,
                metric_fwd_dtype=dtype,
                metric_bwd_dtype=dtype,
                in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
                out_shardings=(None, out_grad_shardings),
            )
209
210
211
212
213
214


class TestDistributedCrossAttn:
    """Fused cross-attention (separate Q, packed KV) sharded over DP/TP mesh axes.

    Each case is checked against flax's ``dot_product_attention`` via ``compare_ops``.
    """

    def generate_collectives_count_ref(self):
        """Return the expected collective byte counts: one FP32 all-reduce for the loss."""
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
        """Build Q, packed KV, an optional mask, and their partition specs.

        Args:
            shape: Query shape, unpacked as ``(batch, seqlen, heads, hidden)``. KV is
                built as ``(batch, seqlen, 2, heads, hidden)``.

        Returns:
            ``((q, kv, mask), (q_pspec, kv_pspec, mask_pspec))``. ``mask`` and
            ``mask_pspec`` are ``None`` when no mask tensor is built for ``attn_mask_type``.
        """
        batch, seqlen, heads, hidden = shape

        q = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype)

        # NOTE(review): PADDING_MASK is paired with make_causal_mask and CAUSAL_MASK with
        # make_self_mask. This pairing looks inverted; confirm against utils before changing.
        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            mask = make_causal_mask(batch, seqlen)
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        # Batch is sharded over the DP axis and heads over the TP axis.
        q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
        kv_pspec = PartitionSpec(
            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
        )
        # Derive the spec from whether a mask tensor exists, so the spec never
        # disagrees with the (possibly None) mask that is passed in.
        mask_pspec = (
            PartitionSpec(mesh_resource.dp_resource, None, None, None) if mask is not None else None
        )

        return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        """Compare sharded fused cross-attention (BSHD_BS2HD layout) with the flax reference."""
        attn_bias_type = AttnBiasType.NO_BIAS
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        def target_func(q, kv, mask):
            return jnp.mean(
                fused_attn(
                    (q, kv),
                    None,
                    mask,
                    None,
                    attn_bias_type=attn_bias_type,
                    attn_mask_type=attn_mask_type,
                    qkv_layout=QKVLayout.BSHD_BS2HD,
                    scaling_factor=scaling_factor,
                    dropout_probability=dropout_prob,
                    is_training=is_training,
                ),
                dtype=jnp.float32,
            )

        def ref_func(query, kv, mask):
            # kv packs key and value along axis -3; drop only that axis after splitting.
            # query is not packed, so it is passed through unchanged. A bare squeeze would
            # wrongly remove any size-1 batch or head dimension.
            key, value = jnp.split(kv, [1], axis=-3)
            key = jnp.squeeze(key, axis=-3)
            value = jnp.squeeze(value, axis=-3)

            output = dot_product_attention(
                query,
                key,
                value,
                bias=None,
                mask=mask,
                # Dropout is active only when not deterministic, i.e. in training.
                deterministic=not is_training,
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )

            return jnp.mean(output, dtype=jnp.float32)

        (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, attn_mask_type, dtype
        )
        collective_count_ref = self.generate_collectives_count_ref()
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
            kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
            mask_ = (
                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
            )

            compare_ops(
                target_func,
                ref_func,
                [q_, kv_, mask_],
                collective_count_ref,
                grad_args=(0, 1),
                metric_fwd_dtype=dtype,
                metric_bwd_dtype=dtype,
                in_shardings=(q_pspec, kv_pspec, mask_pspec),
                out_shardings=(None, (q_pspec, kv_pspec)),
            )
336
337


338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
@pytest.mark.parametrize(
    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
    "data_shape",
    [
        pytest.param([2, 512, 12, 128], id="2-512-12-128"),
        pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
    ],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
    ],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "load_balanced",
    [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
368
class TestDistributedContextParallelSelfAttn:
369
370
371

    def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
        batch, seqlen, heads, hidden = shape
372
        kv_shape = (batch, seqlen, heads // kv_groups, hidden)
373
374
375
376
377
        qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
        q = random.normal(qkey, shape, dtype=dtype)
        k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
        v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)

378
379
380
381
382
383
384
385
386
387
388
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        from test_fused_attn import make_mask

        q_idx, _ = gen_valid(batch, seqlen, 0.0)
        kv_idx, _ = gen_valid(batch, seqlen, 0.0)
        mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type)
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404

        return q, k, v, mask

    def qkv_to_layout(self, q, k, v, qkv_layout):
        qkv_args = ()
        match qkv_layout:
            case QKVLayout.BSHD_BS2HD:
                k, v = map(partial(jnp.expand_dims, axis=-3), [k, v])
                kv = jnp.concatenate((k, v), axis=-3)
                qkv_args = (q, kv)
            case QKVLayout.BSHD_BSHD_BSHD:
                qkv_args = (q, k, v)
            case _:
                raise ValueError(f"Unsupported {qkv_layout=}")
        return qkv_args

405
    def impl_test_contex_parallel_attn(
406
407
408
409
410
411
412
413
414
415
416
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
417
        cp_strategy,
418
419
420
421
422
423
424
    ):
        attn_bias_type = AttnBiasType.NO_BIAS
        dropout_prob = 0.0
        is_training = True
        dp_size, cp_size, tp_size = mesh_shape
        qkv_format = get_qkv_format(qkv_layout)

425
        batch, seqlen, num_head, hidden = data_shape
426
        num_kv_heads = num_head // kv_groups
427
        scaling_factor = 1.0 / np.sqrt(num_head)
428

429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
        def check_has_backend_for_mask(mask_type):
            return is_fused_attn_kernel_available(
                dtype,
                dtype,
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                dropout_prob,
                num_head,
                num_kv_heads,
                seqlen,
                seqlen,
                hidden,
                None,
            )  # no SWA for CP

        # For causal masking we depend on having bottom right support also.
        # The API does not check this and instead we rely on lower level checks to raise
        # and exception if the step backend is not supported. This was a deliberate API
        # decision to keep the CP size or flag out of the function.
        has_backend = check_has_backend_for_mask(attn_mask_type)
        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

        if not has_backend:
            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")
455

456
457
458
        if dp_size > 1 and batch % dp_size != 0:
            pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

459
        # make sure the mesh even divides cp and tp axis
460
461
462
463
        if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
            pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

        def target_func(q, k, v, mask):
464
465
466
467
468
469
470
471
472
473
474
            return fused_attn(
                self.qkv_to_layout(q, k, v, qkv_layout),
                None,  # bias
                mask,
                None,  # seed
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_prob,
                is_training=is_training,
475
                context_parallel_strategy=cp_strategy,
476
477
                context_parallel_causal_load_balanced=load_balanced,
                context_parallel_axis="cp",
478
479
            ).astype(dtype)

480
481
        def ref_func(q, k, v, mask):
            output = general_dot_product_attention(
482
483
484
485
486
                q,
                k,
                v,
                bias=None,
                mask=mask,
487
488
                deterministic=not is_training,
                scale_factor=scaling_factor,
489
490
491
492
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )
493
494
495
496
497
498
499
500
501
502
            return output.astype(dtype)

        def grad_func(func, *args, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            _, max_seq_len, num_heads, _ = data_shape
            gradient_multiplier = max_seq_len * num_heads
            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
                gradient_multiplier /= 10
            ret_valid = func(*args, **kwargs)
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
503
504
505

        q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)

506
507
        diff_argnums = (0, 1, 2)

508
        # Single GPU (reference)
509
510
511
512
513
514
        ref_func_jit = jax.jit(
            jax.value_and_grad(
                lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums
            )
        )
        ref_fwd, ref_grads = ref_func_jit(q, k, v, mask)
515
516
517
518

        # Multi GPU (function under test)
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
519
        with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False):
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
            qkv_ps = PartitionSpec(
                mesh_resource.dp_resource,
                mesh_resource.cp_resource,
                mesh_resource.tp_resource,
                None,
            )
            qkv_sharding = NamedSharding(mesh, qkv_ps)

            mask_ps = PartitionSpec(
                mesh_resource.dp_resource, None, mesh_resource.cp_resource, None
            )
            mask_sharding = NamedSharding(mesh, mask_ps)

            reorder = partial(
                reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
            )
            inverse_reorder = partial(
                inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
            )

            if load_balanced:
                q, k, v = jax.tree.map(reorder, (q, k, v))

            q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v])
            mask_ = jax.device_put(mask, device=mask_sharding)

            target_func_jit = jax.jit(
547
548
549
550
                jax.value_and_grad(
                    lambda q, k, v, mask: grad_func(target_func, q, k, v, mask),
                    argnums=diff_argnums,
                ),
551
552
553
554
555
556
557
558
559
560
561
562
                in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
                out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
            )

            target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_)

            if load_balanced:
                target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
                target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])

            has_diffs = False

563
564
565
566
            print_debug_tensor_stats("target", target_fwd)
            print_debug_tensor_stats("ref", ref_fwd)
            print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd))
            assert_allclose(target_fwd, ref_fwd, dtype=dtype)
567
568
569
570
571
572

            for i in range(len(target_grads)):
                if ref_grads[i] is None or target_grads[i] is None:
                    # expect both none if one is
                    assert target_grads[i] is None and ref_grads[i] is None
                else:
573
574
575
576
577
578
579
                    print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i])
                    print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i])
                    print_debug_tensor_stats(
                        f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i])
                    )

                assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
580

581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
    def test_contex_parallel_allgather_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        return self.impl_test_contex_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
        )

    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        return self.impl_test_contex_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
        )

635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660

class TestReorderCausalLoadBalancing:
    """Round-trip check: the inverse reorder exactly undoes the causal load-balancing reorder."""

    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest.mark.parametrize(
        "shape",
        [
            pytest.param([1, 16, 1, 1], id="1-16-1-1"),
            pytest.param([4, 32, 12, 32], id="4-32-12-32"),
            pytest.param([3, 32, 8, 64], id="3-32-8-64"),
        ],
    )
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    def test(self, cp_size, shape, qkv_format):
        """Reorder then inverse-reorder a random tensor and require an exact match."""
        source = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        # For SBHD the two leading axes of the generated tensor are swapped
        # (presumably putting the sequence axis first).
        if qkv_format == QKVFormat.SBHD:
            source = source.swapaxes(0, 1)
        expected = source.copy()

        # cp_size (arg 1) and the tensor format (arg 2) are static for jit.
        jit_forward = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
        jit_backward = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])

        shuffled = jit_forward(source, cp_size, qkv_format)
        restored = jit_backward(shuffled, cp_size, qkv_format)

        assert jnp.array_equal(restored, expected)