# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import pytest

import jax
import jax.numpy as jnp
from jax import random

from distributed_test_base import (
    generate_configs,
    generate_context_parallel_configs_for_attn,
    generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper

from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
    CPStrategy,
    ReorderStrategy,
)

DTYPES = [jnp.bfloat16]

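# (batch, seqlen, num_head, hidden) data shapes to run at each test level.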
DISTRIBUTED_SELF_ATTN_DATA_SHAPES = {
    "L0": [()],
    "L1": [(32, 1024, 16, 128)],
    "L2": [(32, 512, 12, 64)],
}


class TestDistributedSelfAttn:

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None
        tp_size = 1
        if mesh_resource.tpsp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tpsp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        # dbias is reduced across the DP group, so count its bytes only when DP is enabled.
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        # for loss and dbias
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    def impl_test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
        softmax_type,
        use_shardy,
    ):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        dropout_prob = 0.0
        is_training = True
        batch, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            is_training,
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref(
            mesh_shape,
            mesh_axes,
            mesh_resource,
            attn_bias_type != AttnBiasType.NO_BIAS,
            data_shape,
            dtype,
        )
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            hidden,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BS3HD,
            bias_shape,
            None,
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_mask_type",
        [
            pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        ],
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize(
        "softmax_type",
        [
            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
        ],
    )
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
        softmax_type,
    ):
        self.impl_test_self_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            attn_bias_type,
            bias_shape,
            attn_mask_type,
            dtype,
            softmax_type,
            use_shardy=False,
        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "softmax_type",
        [
            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
        ],
    )
    def test_self_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        attn_bias_type,
        bias_shape,
        softmax_type,
    ):
        data_shape = (32, 512, 12, 64)
        self.impl_test_self_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            attn_bias_type,
            bias_shape,
            AttnMaskType.PADDING_MASK,
            jnp.bfloat16,
            softmax_type,
            use_shardy=True,
        )


DISTRIBUTED_CROSS_ATTN_DATA_SHAPES = {
    "L0": [()],
    "L1": [[32, 512, 16, 64]],
    "L2": [[32, 128, 12, 64]],
}


class TestDistributedCrossAttn:

    def generate_collectives_count_ref(self):
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_CROSS_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize(
        "softmax_type",
        [
            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
        ],
    )
    def test_cross_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        softmax_type,
    ):
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True

        batch, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            is_training,
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref()
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            hidden,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BSHD_BS2HD,
            bias_shape,
            None,
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()


DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
    pytest.param(
        QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
    ),
]

DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
    # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
    pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
    pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
]


class TestDistributedContextParallelSelfAttn:

    def impl_test_context_parallel_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        cp_strategy,
        use_shardy,
        use_scan_ring=False,
        window_size=None,
    ):
        if qkv_layout.is_thd():
            if cp_strategy == CPStrategy.ALL_GATHER:
                pytest.skip("THD doesn't support all gather context parallelism.")
            if not load_balanced and cp_strategy == CPStrategy.RING:
                pytest.skip("THD + ring doesn't support unbalanced context parallelism.")

        assert not use_scan_ring or cp_strategy == CPStrategy.RING
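        # use_scan_ring only applies to the RING strategy; NVTE_FUSED_RING_ATTENTION_USE_SCAN
        # presumably toggles a scan-based loop in the fused ring attention (inferred from the
        # flag name), so it is pinned explicitly below so each case exercises a known path.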

        if use_scan_ring:
            os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
        else:
            os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"

        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True
        # Context parallel does not support softmax_offset
        softmax_type = AttnSoftmaxType.VANILLA_SOFTMAX
        dp_size, cp_size, tp_size = mesh_shape

        batch, seqlen, num_head, hidden = data_shape

        # Scale the sequence length by 2*CP so it's never too small as we scale up the test.
        # 2*CP is used since we split into two CP groups for load balancing.
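        # e.g. a base seqlen of 128 with cp_size=4 becomes 128 * 4 * 2 = 1024.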
        seqlen = seqlen * cp_size * 2
        data_shape = batch, seqlen, num_head, hidden

        num_kv_heads = num_head // kv_groups

        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_kv_heads,
            hidden,
            hidden,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            SeqDescFormat.SegmentIDs,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            cp_strategy=cp_strategy,
            cp_load_balanced=load_balanced,
        )

        def check_has_backend_for_mask(mask_type):
            return is_fused_attn_kernel_available(
                is_training,
                dtype,
                dtype,
                qkv_layout,
                attn_bias_type,
                mask_type,
                softmax_type,
                dropout_prob,
                num_head,
                num_kv_heads,
                seqlen,
                seqlen,
                hidden,
                hidden,
                None,
            )  # no SWA for CP

        # For causal masking we depend on having bottom right support also.
        # The API does not check this and instead we rely on lower level checks to raise
        # an exception if the step backend is not supported. This was a deliberate API
        # decision to keep the CP size or flag out of the function.
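        # (With CP, later sequence chunks effectively see a bottom-right aligned causal mask,
        # which is why CAUSAL_BOTTOM_RIGHT_MASK support is also checked below.)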
        has_backend = check_has_backend_for_mask(attn_mask_type)
        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

        if not has_backend:
            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

        if dp_size > 1 and batch % dp_size != 0:
            pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

        # make sure the mesh evenly divides the cp and tp axes
        if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
            pytest.skip(f"Skipping {kv_groups=} not a multiple of {data_shape=} or {tp_size=}")

        runner.test_backward()
        del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    def test_context_parallel_allgather_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        qkv_layout,
    ):
        kv_groups = 8
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced=True,
            cp_strategy=CPStrategy.ALL_GATHER,
            use_shardy=True,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    def test_context_parallel_allgather_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
            use_shardy=False,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    @pytest.mark.parametrize(
        "use_scan",
        [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
    )
    @pytest.mark.parametrize(
        "window_size",
        [
            pytest.param((-1, -1), id="window_size(-1, -1)"),
            pytest.param((20, 0), id="window_size(20, 0)"),
        ],
    )
    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        use_scan,
        window_size,
    ):
        if window_size != (-1, -1) and not qkv_layout.is_thd():
            pytest.skip("Sliding window attention is only supported for THD layout")
        if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan:
            pytest.skip(
                "When context parallelism and sliding window attention are used, "
                "scanloop is not supported"
            )
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
            use_shardy=False,
            use_scan_ring=use_scan,
            window_size=window_size,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    def test_context_parallel_ring_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        qkv_layout,
    ):
        kv_groups = 8
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced=True,
            cp_strategy=CPStrategy.RING,
            use_shardy=True,
            use_scan_ring=True,
        )


REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
    "L0": [[]],
    "L1": [[3, 32, 8, 64]],
    "L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}


class TestReorderCausalLoadBalancing:
    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    @pytest.mark.parametrize(
        "reorder_strategy",
        [
            pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
            pytest.param(ReorderStrategy.Striped, id="Striped"),
        ],
    )
    def test(self, cp_size, shape, qkv_format, reorder_strategy):
        tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        seq_dim = 1
        if qkv_format == QKVFormat.SBHD:
            tensor = tensor.swapaxes(0, 1)
            seq_dim = 0

        ref = tensor.copy()

        reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
        inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
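        # reorder_strategy, cp_size and seq_dim (positional args 1-3) are static under jit.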

        reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
        inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
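        # Round trip: applying the inverse reorder must restore the original tensor exactly.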

        assert jnp.array_equal(inversed, ref)