# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
import os
import pytest
7
8
9
import jax
import jax.numpy as jnp
from jax import random
10
11
12
13
from distributed_test_base import (
    generate_configs,
    generate_context_parallel_configs,
    generate_collectives_count,
14
)
15
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
16
17
18
19
from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
    AttnBiasType,
    AttnMaskType,
20
    QKVLayout,
21
22
23
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
24
    CPStrategy,
Reese Wang's avatar
Reese Wang committed
25
    ReorderStrategy,
26
27
)

28

29
# Data types used to parametrize the `dtype` argument of the self- and cross-attention tests.
DTYPES = [jnp.bfloat16]
30
31
32
33


class TestDistributedSelfAttn:
    """Distributed fused self-attention tests using the packed ``QKVLayout.BS3HD`` layout."""

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        """Build the reference collective counts handed to ``FusedAttnRunner``.

        The all-reduce budget is 4 bytes for the loss plus, when a bias is used and a
        data-parallel resource is configured, the bytes of the dbias tensor.

        Args:
            mesh_shape: Size of every mesh axis, indexed like ``mesh_axes``.
            mesh_axes: Names of the mesh axes.
            mesh_resource: Object exposing ``dp_resource`` and ``tp_resource``.
            with_bias: Whether the attention under test uses a bias.
            shape: 4-element shape; elements 1 and 2 are read as seqlen and heads.
            dtype: Data type whose canonical item size is used for the bias bytes.

        Returns:
            The value of ``generate_collectives_count`` with no all-gather/other traffic.
        """
        element_size = jax.dtypes.canonicalize_dtype(dtype).itemsize
        _, seq_len, num_heads, _ = shape

        dp_enabled = mesh_resource.dp_resource is not None
        tp_axis = mesh_resource.tp_resource
        tp_size = 1 if tp_axis is None else mesh_shape[mesh_axes.index(tp_axis)]

        loss_bytes = 4  # 1 * FP32
        dbias_bytes = int(with_bias) * (num_heads // tp_size) * seq_len * seq_len * element_size
        # All-reduce covers the loss and, with data parallelism, dbias.
        total_allreduce = loss_bytes + (dbias_bytes * dp_enabled)
        return generate_collectives_count(allreduce=total_allreduce, allgather=0, other=0)

    def impl_test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
        use_shardy,
    ):
        """Shared body of the self-attention tests.

        Sets the ``jax_use_shardy_partitioner`` flag, skips when no fused-attention
        kernel is reported for the configuration, and otherwise runs
        ``FusedAttnRunner.test_backward`` with the expected collective counts.

        Args:
            device_count: Number of devices handed to the runner.
            mesh_shape: Size of every mesh axis.
            mesh_axes: Names of the mesh axes.
            mesh_resource: Mesh resource description forwarded to the runner.
            data_shape: ``(batch, seqlen, num_head, hidden)``.
            attn_bias_type: ``AttnBiasType`` under test.
            bias_shape: Bias shape forwarded to the runner (``None`` without bias).
            attn_mask_type: ``AttnMaskType`` under test.
            dtype: Data type of the attention inputs.
            use_shardy: Value written to ``jax_use_shardy_partitioner``.
        """
        jax.config.update("jax_use_shardy_partitioner", use_shardy)

        batch_size, seq_len, heads, hidden_dim = data_shape
        layout = QKVLayout.BS3HD
        training = True
        dropout = 0.0

        kernel_found = is_fused_attn_kernel_available(
            training,
            dtype,
            dtype,
            layout,
            attn_bias_type,
            attn_mask_type,
            dropout,
            heads,
            heads,
            seq_len,
            seq_len,
            hidden_dim,
            hidden_dim,
            None,  # no window
        )
        if not kernel_found:
            pytest.skip("No FusedAttn backend found")

        uses_bias = attn_bias_type != AttnBiasType.NO_BIAS
        expected_collectives = self.generate_collectives_count_ref(
            mesh_shape, mesh_axes, mesh_resource, uses_bias, data_shape, dtype
        )

        attn_runner = FusedAttnRunner(
            batch_size,
            seq_len,
            seq_len,
            heads,
            heads,
            hidden_dim,
            hidden_dim,
            attn_bias_type,
            attn_mask_type,
            dropout,
            dtype,
            training,
            layout,
            bias_shape,
            None,
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=expected_collectives,
        )
        attn_runner.test_backward()

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "data_shape",
        [
            pytest.param((32, 512, 12, 64), id="32-512-12-64"),
            pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_mask_type",
        [
            pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        ],
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
    ):
        """Self-attention backward test with the Shardy partitioner flag turned off."""
        self.impl_test_self_attn(
            device_count=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            data_shape=data_shape,
            attn_bias_type=attn_bias_type,
            bias_shape=bias_shape,
            attn_mask_type=attn_mask_type,
            dtype=dtype,
            use_shardy=False,
        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
        ],
    )
    def test_self_attn_shardy(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
    ):
        """Self-attention backward test with the Shardy partitioner flag turned on.

        Uses a single fixed data shape, the padding mask and bfloat16.
        """
        self.impl_test_self_attn(
            device_count=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            data_shape=(32, 512, 12, 64),
            attn_bias_type=attn_bias_type,
            bias_shape=bias_shape,
            attn_mask_type=AttnMaskType.PADDING_MASK,
            dtype=jnp.bfloat16,
            use_shardy=True,
        )

195
196
197
198
199

class TestDistributedCrossAttn:
    """Distributed fused cross-attention tests using the ``QKVLayout.BSHD_BS2HD`` layout."""

    def generate_collectives_count_ref(self):
        """Return the reference collective counts: a single 4-byte all-reduce for the loss."""
        loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=loss_bytes, allgather=0, other=0)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        """Run the bias-free cross-attention backward test on the given mesh.

        The test is skipped when no fused-attention kernel is reported for the
        configuration. ``data_shape`` is ``(batch, seqlen, num_head, hidden)``.
        """
        batch_size, seq_len, heads, hidden_dim = data_shape
        layout = QKVLayout.BSHD_BS2HD
        bias_type = AttnBiasType.NO_BIAS
        no_bias_shape = None
        training = True
        dropout = 0.0

        kernel_found = is_fused_attn_kernel_available(
            training,
            dtype,
            dtype,
            layout,
            bias_type,
            attn_mask_type,
            dropout,
            heads,
            heads,
            seq_len,
            seq_len,
            hidden_dim,
            hidden_dim,
            None,  # no window
        )
        if not kernel_found:
            pytest.skip("No FusedAttn backend found")

        expected_collectives = self.generate_collectives_count_ref()
        attn_runner = FusedAttnRunner(
            batch_size,
            seq_len,
            seq_len,
            heads,
            heads,
            hidden_dim,
            hidden_dim,
            bias_type,
            attn_mask_type,
            dropout,
            dtype,
            training,
            layout,
            no_bias_shape,
            None,
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=expected_collectives,
        )
        attn_runner.test_backward()
262
263


264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
# (qkv_layout, attn_mask_type) combinations shared by the context-parallel tests.
# Each id is "<layout>-<mask>"; the KV-packed NO_MASK id previously read "HD_KVPACKED".
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
    pytest.param(
        QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
    ),
]

# Base (batch, seqlen, num_head, hidden) shapes for the context-parallel tests.
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
    # The sequence length is multiplied by 2 * cp_size inside the test implementation so
    # that we don't run with tiny per-device sizes as the CP degree grows.
    pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
    pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
]


281
class TestDistributedContextParallelSelfAttn:
    """Context-parallel fused self-attention tests for the all-gather and ring strategies."""

    def impl_test_context_parallel_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        cp_strategy,
        use_shardy,
        use_scan_ring=False,
        window_size=None,
    ):
        """Shared body of the context-parallel attention tests.

        Skips THD configurations the strategy does not support, exports
        ``NVTE_FUSED_RING_ATTENTION_USE_SCAN`` for the duration of the run, and
        delegates the actual work to ``_run_context_parallel_attn``.

        Args:
            device_count: Number of devices handed to the runner.
            mesh_shape: ``(dp_size, cp_size, tp_size)``.
            mesh_axes: Names of the mesh axes.
            mesh_resource: Mesh resource description forwarded to the runner.
            data_shape: ``(batch, seqlen, num_head, hidden)``; seqlen is scaled by
                ``2 * cp_size`` before use.
            kv_groups: Divisor applied to ``num_head`` to get the number of KV heads.
            attn_mask_type: ``AttnMaskType`` under test.
            dtype: Data type of the attention inputs.
            qkv_layout: ``QKVLayout`` under test.
            load_balanced: Forwarded to the runner as ``cp_load_balanced``.
            cp_strategy: ``CPStrategy`` under test.
            use_shardy: Value written to ``jax_use_shardy_partitioner``.
            use_scan_ring: Selects "1"/"0" for ``NVTE_FUSED_RING_ATTENTION_USE_SCAN``;
                only valid together with ``CPStrategy.RING``.
            window_size: Attention window forwarded to the runner.
        """
        if qkv_layout.is_thd():
            if cp_strategy == CPStrategy.ALL_GATHER:
                pytest.skip("THD doesn't support all gather context parallelism.")
            if not load_balanced and cp_strategy == CPStrategy.RING:
                pytest.skip("THD + ring doesn't support unbalanced context parallelism.")

        assert not use_scan_ring or cp_strategy == CPStrategy.RING

        scan_env_var = "NVTE_FUSED_RING_ATTENTION_USE_SCAN"
        os.environ[scan_env_var] = "1" if use_scan_ring else "0"
        try:
            self._run_context_parallel_attn(
                device_count,
                mesh_shape,
                mesh_axes,
                mesh_resource,
                data_shape,
                kv_groups,
                attn_mask_type,
                dtype,
                qkv_layout,
                load_balanced,
                cp_strategy,
                use_shardy,
                window_size,
            )
        finally:
            # pytest.skip() and test failures propagate as exceptions, so the variable
            # must be removed here or it would leak into every following test.
            os.environ.pop(scan_env_var, None)

    def _run_context_parallel_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        cp_strategy,
        use_shardy,
        window_size,
    ):
        """Build the runner, apply the remaining skip checks and run the backward test."""
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True
        dp_size, cp_size, tp_size = mesh_shape

        batch, seqlen, num_head, hidden = data_shape

        # Scale the sequence length by 2*CP so it's never too small as we scale up the test.
        # 2*CP is used since we split into two CP groups for load balancing.
        seqlen = seqlen * cp_size * 2
        data_shape = batch, seqlen, num_head, hidden

        num_kv_heads = num_head // kv_groups

        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_kv_heads,
            hidden,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            SeqDescFormat.SegmentIDs,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            cp_strategy=cp_strategy,
            cp_load_balanced=load_balanced,
        )

        def check_has_backend_for_mask(mask_type):
            # NOTE(review): availability is probed without a window even when
            # `window_size` is given to the runner - confirm this is intended.
            return is_fused_attn_kernel_available(
                is_training,
                dtype,
                dtype,
                qkv_layout,
                attn_bias_type,
                mask_type,
                dropout_prob,
                num_head,
                num_kv_heads,
                seqlen,
                seqlen,
                hidden,
                hidden,
                None,
            )  # no SWA for CP

        # For causal masking we depend on having bottom right support also.
        # The API does not check this and instead we rely on lower level checks to raise
        # an exception if the step backend is not supported. This was a deliberate API
        # decision to keep the CP size or flag out of the function.
        has_backend = check_has_backend_for_mask(attn_mask_type)
        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

        if not has_backend:
            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

        if dp_size > 1 and batch % dp_size != 0:
            pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

        # Make sure the heads divide evenly into KV groups and across the tp axis.
        if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
            pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

        runner.test_backward()

    @pytest.mark.parametrize(
        "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    def test_context_parallel_allgather_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        qkv_layout,
    ):
        """All-gather context parallelism with the Shardy partitioner flag turned on.

        Runs load-balanced with ``kv_groups=8`` on the first data shape only.
        """
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups=8,
            attn_mask_type=attn_mask_type,
            dtype=dtype,
            qkv_layout=qkv_layout,
            load_balanced=True,
            cp_strategy=CPStrategy.ALL_GATHER,
            use_shardy=True,
        )

    @pytest.mark.parametrize(
        "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    def test_context_parallel_allgather_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        """All-gather context parallelism with the Shardy partitioner flag turned off."""
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
            use_shardy=False,
        )

    @pytest.mark.parametrize(
        "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    @pytest.mark.parametrize(
        "use_scan",
        [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
    )
    @pytest.mark.parametrize(
        "window_size",
        [
            pytest.param((-1, -1), id="window_size(-1, -1)"),
            pytest.param((20, 0), id="window_size(20, 0)"),
        ],
    )
    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        use_scan,
        window_size,
    ):
        """Ring context parallelism with the Shardy partitioner flag turned off.

        A window other than ``(-1, -1)`` is only exercised for THD layouts and
        without the scan loop; the other combinations are skipped.
        """
        if window_size != (-1, -1):
            if not qkv_layout.is_thd():
                pytest.skip("Sliding window attention is only supported for THD layout")
            if use_scan:
                pytest.skip(
                    "When context parallelism and sliding window attention are used, "
                    "scanloop is not supported"
                )
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
            use_shardy=False,
            use_scan_ring=use_scan,
            window_size=window_size,
        )

    @pytest.mark.parametrize(
        "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    def test_context_parallel_ring_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        qkv_layout,
    ):
        """Ring context parallelism (scan loop) with the Shardy partitioner flag turned on.

        Runs load-balanced with ``kv_groups=8`` on the first data shape only.
        """
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups=8,
            attn_mask_type=attn_mask_type,
            dtype=dtype,
            qkv_layout=qkv_layout,
            load_balanced=True,
            cp_strategy=CPStrategy.RING,
            # Was False, which made this "shardy" test run without Shardy enabled.
            use_shardy=True,
            use_scan_ring=True,
        )

572
573
574
575
576
577
578
579
580
581
582
583

class TestReorderCausalLoadBalancing:
    """Round-trip test for the causal load-balancing reorder helpers."""

    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest.mark.parametrize(
        "shape",
        [
            pytest.param([1, 16, 1, 1], id="1-16-1-1"),
            pytest.param([4, 32, 12, 32], id="4-32-12-32"),
            pytest.param([3, 32, 8, 64], id="3-32-8-64"),
        ],
    )
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    @pytest.mark.parametrize(
        "reorder_strategy",
        [
            pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
            pytest.param(ReorderStrategy.Striped, id="Striped"),
        ],
    )
    def test(self, cp_size, shape, qkv_format, reorder_strategy):
        """Check that the inverse reorder restores the tensor that was reordered.

        For ``QKVFormat.SBHD`` the first two axes of the random input are swapped and
        the sequence dimension passed to the helpers is 0 instead of 1.
        """
        is_sbhd = qkv_format == QKVFormat.SBHD
        data = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        if is_sbhd:
            data = data.swapaxes(0, 1)
        seq_dim = 0 if is_sbhd else 1

        expected = data.copy()

        # Strategy, cp_size and seq_dim are compile-time constants for both helpers.
        static_positions = [1, 2, 3]
        forward_fn = jax.jit(reorder_causal_load_balancing, static_argnums=static_positions)
        backward_fn = jax.jit(
            inverse_reorder_causal_load_balancing, static_argnums=static_positions
        )

        shuffled = forward_fn(data, reorder_strategy, cp_size, seq_dim)
        restored = backward_fn(shuffled, reorder_strategy, cp_size, seq_dim)

        assert jnp.array_equal(restored, expected)