# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os

import pytest

import jax
import jax.numpy as jnp
from jax import random

from distributed_test_base import (
    generate_configs,
    generate_context_parallel_configs_for_attn,
    generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper

from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
    CPStrategy,
    ReorderStrategy,
)

30

31
# Dtypes exercised by the distributed attention tests (BF16 only).
DTYPES = [jnp.bfloat16]
32

33
34
35
36
37
38
# Self-attention input shapes, keyed by test level for pytest_parametrize_wrapper.
# Each shape is (batch, seqlen, num_head, head_dim); L0 carries no shapes
# (presumably the quick level runs none of these cases — confirm against
# pytest_parametrize_wrapper's level selection).
DISTRIBUTED_SELF_ATTN_DATA_SHAPES = {
    "L0": [()],
    "L1": [(32, 1024, 16, 128)],
    "L2": [(32, 512, 12, 64)],
}

39
40
41

class TestDistributedSelfAttn:
    """Distributed self-attention (packed BS3HD layout) tests over DP/TP meshes."""

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        """Build the expected collective byte counts for one backward pass.

        The loss is all-reduced as a single FP32 scalar (4 bytes).  When a bias
        is used and data parallelism is enabled, the dbias gradient is also
        all-reduced; heads are sharded over TP, hence ``heads // tp_size``.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, heads, _ = shape

        is_dp_enabled = mesh_resource.dp_resource is not None
        tp_size = 1
        if mesh_resource.tpsp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tpsp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        # for loss and dbias
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    def impl_test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
        softmax_type,
        use_shardy,
    ):
        """Shared driver: run a fused self-attention backward pass on a mesh
        and compare the emitted collectives against the reference counts."""
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        dropout_prob = 0.0
        is_training = True
        batch, seqlen, num_head, hidden = data_shape

        # Skip early when no fused-attention backend supports this config.
        if not is_fused_attn_kernel_available(
            is_training,
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref(
            mesh_shape,
            mesh_axes,
            mesh_resource,
            attn_bias_type != AttnBiasType.NO_BIAS,
            data_shape,
            dtype,
        )

        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            hidden,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BS3HD,
            bias_shape,
            None,
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_mask_type",
        [
            pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        ],
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize(
        "softmax_type",
        [
            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
        ],
    )
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
        softmax_type,
    ):
        """Self-attention with the GSPMD partitioner."""
        self.impl_test_self_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            attn_bias_type,
            bias_shape,
            attn_mask_type,
            dtype,
            softmax_type,
            use_shardy=False,
        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "softmax_type",
        [
            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
        ],
    )
    def test_self_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        attn_bias_type,
        bias_shape,
        softmax_type,
    ):
        """Self-attention smoke test with the Shardy partitioner (single
        fixed shape/mask/dtype to keep the matrix small)."""
        data_shape = (32, 512, 12, 64)
        self.impl_test_self_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            attn_bias_type,
            bias_shape,
            AttnMaskType.PADDING_MASK,
            jnp.bfloat16,
            softmax_type,
            use_shardy=True,
        )

225

226
227
228
229
230
231
232
# Cross-attention input shapes, keyed by test level for pytest_parametrize_wrapper.
# Each shape is (batch, seqlen, num_head, head_dim); L0 carries no shapes.
DISTRIBUTED_CROSS_ATTN_DATA_SHAPES = {
    "L0": [()],
    "L1": [[32, 512, 16, 64]],
    "L2": [[32, 128, 12, 64]],
}


233
234
235
236
class TestDistributedCrossAttn:
    """Distributed cross-attention (BSHD query + packed BS2HD key/value) tests."""

    def generate_collectives_count_ref(self):
        """Expected collectives: only the FP32 scalar loss is all-reduced."""
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_CROSS_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize(
        "softmax_type",
        [
            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
        ],
    )
    def test_cross_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        softmax_type,
    ):
        """Run a cross-attention backward pass on a mesh and verify the
        emitted collectives against the reference counts."""
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True

        batch, seqlen, num_head, hidden = data_shape

        # Skip early when no fused-attention backend supports this config.
        if not is_fused_attn_kernel_available(
            is_training,
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref()
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            hidden,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BSHD_BS2HD,
            bias_shape,
            None,
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()
317
318


319
320
321
322
323
324
325
326
327
328
329
# Layout/mask pairs exercised by the context-parallel tests.
# NOTE(review): the NO_MASK kvpacked id was "HD_KVPACKED-NO_MASK", inconsistent
# with the "BSHD_" prefix used by every other id; normalized here.
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
    pytest.param(
        QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
    ),
]

# Base shapes (batch, seqlen, num_head, head_dim) for context-parallel tests.
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
    # Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
    pytest.param([2, 128, 8, 128], id="2-128xCPx2-8-128"),
    pytest.param([4, 256, 16, 64], id="4-256xCPx2-16-64"),
]


336
class TestDistributedContextParallelSelfAttn:
    """Context-parallel self-attention tests (all-gather and ring strategies)."""

    # TODO(KshitijLakhani): parametrize num_segments_per_seq for all CP tests
    def impl_test_context_parallel_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        cp_strategy,
        use_shardy,
        use_scan_ring=False,
        window_size=None,
        stripe_size=None,
        num_segments_per_seq=None,
    ):
        """Shared driver for all context-parallel attention tests.

        Builds a FusedAttnRunner on a (dp, cp, tp) mesh, skips unsupported
        configurations, and runs the backward-pass check.
        """
        if qkv_layout.is_thd():
            if not load_balanced and (
                cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER
            ):
                pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")

        # Scan-based ring attention is only meaningful with the RING strategy.
        assert not use_scan_ring or cp_strategy == CPStrategy.RING

        os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" if use_scan_ring else "0"
        try:
            jax.config.update("jax_use_shardy_partitioner", use_shardy)
            attn_bias_type = AttnBiasType.NO_BIAS
            bias_shape = None
            dropout_prob = 0.0
            is_training = True
            # Context parallel does not support softmax_offset
            softmax_type = AttnSoftmaxType.VANILLA_SOFTMAX
            dp_size, cp_size, tp_size = mesh_shape
            batch, seqlen, num_head, hidden = data_shape

            # Scale the sequence length by 2*CP so its never too small as we scale up test.
            # 2*CP is used since we split into two CP groups for load balancing.
            seqlen = seqlen * cp_size * 2
            data_shape = batch, seqlen, num_head, hidden

            num_kv_heads = num_head // kv_groups
            runner = FusedAttnRunner(
                batch,
                seqlen,
                seqlen,
                num_head,
                num_kv_heads,
                hidden,
                hidden,
                attn_bias_type,
                attn_mask_type,
                softmax_type,
                dropout_prob,
                dtype,
                is_training,
                qkv_layout,
                bias_shape,
                window_size,
                SeqDescFormat.SegmentIDs,
                stripe_size=stripe_size,
                num_segments_per_seq=num_segments_per_seq,
                number_of_devices=device_count,
                mesh_shape=mesh_shape,
                mesh_axes=mesh_axes,
                mesh_resource=mesh_resource,
                cp_strategy=cp_strategy,
                cp_load_balanced=load_balanced,
            )

            def check_has_backend_for_mask(mask_type):
                # One-line purpose: query backend availability for a given mask.
                return is_fused_attn_kernel_available(
                    is_training,
                    dtype,
                    dtype,
                    qkv_layout,
                    attn_bias_type,
                    mask_type,
                    softmax_type,
                    dropout_prob,
                    num_head,
                    num_kv_heads,
                    seqlen,
                    seqlen,
                    hidden,
                    hidden,
                    None,
                )  # no SWA for CP

            # For causal masking we depend on having bottom right support also.
            # The API does not check this and instead we rely on lower level checks to raise
            # and exception if the step backend is not supported. This was a deliberate API
            # decision to keep the CP size or flag out of the function.
            has_backend = check_has_backend_for_mask(attn_mask_type)
            if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
                has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

            if not has_backend:
                pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

            if dp_size > 1 and batch % dp_size != 0:
                pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

            # make sure the mesh even divides cp and tp axis
            if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
                pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

            runner.test_backward()
        finally:
            # Always unset the env var, even on skip/failure, so it cannot
            # leak into subsequent tests (it was previously only deleted on
            # the success path).
            del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    def test_context_parallel_allgather_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        qkv_layout,
    ):
        """All-gather CP with the Shardy partitioner (BSHD layouts only)."""
        if qkv_layout.is_thd():
            pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
        kv_groups = 8
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced=True,
            cp_strategy=CPStrategy.ALL_GATHER,
            use_shardy=True,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED")],
    )
    @pytest.mark.parametrize(
        "stripe_size",
        [pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")],
    )
    @pytest.mark.parametrize(
        "window_size",
        [
            pytest.param((-1, -1), id="window_size(-1, -1)"),
            # NOTE(review): the id previously read "window_size(8, 0)" while
            # the value is (5, 0); corrected to match the value.
            pytest.param((5, 0), id="window_size(5, 0)"),
        ],
    )
    @pytest.mark.parametrize(
        "num_segments_per_seq",
        [pytest.param(5, id="SEG-5")],
    )
    def test_context_parallel_allgather_striped_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        window_size,
        stripe_size,
        num_segments_per_seq,
    ):
        """All-gather CP with striped load balancing (THD layout only)."""
        if not qkv_layout.is_thd():
            pytest.skip("Only THD layout is supported for CP + AG + Striped attention")
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
            use_shardy=False,
            window_size=window_size,
            stripe_size=stripe_size,
            num_segments_per_seq=num_segments_per_seq,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    def test_context_parallel_allgather_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        """All-gather CP with dual-chunk load balancing (BSHD layouts only)."""
        if qkv_layout.is_thd():
            pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
            use_shardy=False,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    @pytest.mark.parametrize(
        "use_scan",
        [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
    )
    @pytest.mark.parametrize(
        "window_size",
        [
            pytest.param((-1, -1), id="window_size(-1, -1)"),
            pytest.param((20, 0), id="window_size(20, 0)"),
        ],
    )
    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        use_scan,
        window_size,
    ):
        """Ring-attention CP, optionally using the scan-based ring loop."""
        if window_size != (-1, -1) and not qkv_layout.is_thd():
            pytest.skip("Sliding window attention is only supported for THD layout")
        if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan:
            pytest.skip(
                "When context parallelism and sliding window attention are used, "
                "scanloop is not supported"
            )
        # Set the stripe size to 1 (ring attention only support stripe_size=1)
        stripe_size = 1 if qkv_layout.is_thd() else None
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
            use_shardy=False,
            use_scan_ring=use_scan,
            window_size=window_size,
            stripe_size=stripe_size,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    def test_context_parallel_ring_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        qkv_layout,
    ):
        """Ring-attention CP smoke test with the Shardy partitioner."""
        kv_groups = 8
        # Set the stripe size to 1 (ring attention only support stripe_size=1)
        stripe_size = 1 if qkv_layout.is_thd() else None
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced=True,
            cp_strategy=CPStrategy.RING,
            # NOTE(review): this test's name says "shardy" but the original
            # passed use_shardy=False (unlike test_context_parallel_allgather_
            # attn_shardy, which passes True). Corrected to True — confirm the
            # Shardy partitioner is actually supported for ring CP.
            use_shardy=True,
            use_scan_ring=True,
            stripe_size=stripe_size,
        )

714

715
716
717
718
719
720
# Tensor shapes for the reorder/inverse-reorder round-trip tests, keyed by
# test level for pytest_parametrize_wrapper (L0 carries no shapes).
REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
    "L0": [[]],
    "L1": [[3, 32, 8, 64]],
    "L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}

721
722
723
724
725
726
# (reorder_strategy, stripe_size) pairs; stripe_size only applies to Striped.
REORDER_STRATEGY = [
    pytest.param(ReorderStrategy.DualChunkSwap, None, id="DualChunkSwap"),
    pytest.param(ReorderStrategy.Striped, 1, id="Striped-1"),
    pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"),
]

727

728
729
class TestReorderCausalLoadBalancing:
    """Round-trip tests: inverse_reorder(reorder(x)) must reproduce x."""

    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD])
    @pytest.mark.parametrize(
        "reorder_strategy, stripe_size",
        REORDER_STRATEGY,
    )
    def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_size):
        tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        seq_dim = 1
        if qkv_format == QKVFormat.SBHD:
            tensor = tensor.swapaxes(0, 1)
            seq_dim = 0

        if reorder_strategy == ReorderStrategy.Striped:
            # Read the length from the (possibly transposed) tensor itself.
            # The original used shape[seq_dim], which for SBHD indexes the
            # *batch* dim of the untransposed shape list and skips the wrong
            # cases.
            seq_lens = tensor.shape[seq_dim]
            if seq_lens < (cp_size * stripe_size):
                pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_size=}")

        ref = tensor.copy()

        # Strategy/cp_size/seq_dim/stripe_size are static so jit can specialize.
        reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
        inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])

        reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_size)
        inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_size)

        assert jnp.array_equal(inversed, ref)