# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import pytest

import jax
import jax.numpy as jnp
from jax import random

from distributed_test_base import (
    generate_configs,
    generate_context_parallel_configs_for_attn,
    generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper

from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
    CPStrategy,
    ReorderStrategy,
)

# Data types exercised by the distributed self/cross-attention tests.
DTYPES = [jnp.bfloat16]

# Self-attention data shapes; each shape is unpacked as (batch, seqlen, num_head, hidden).
# NOTE(review): the "L0"/"L1"/"L2" keys are presumably test levels selected by
# ``pytest_parametrize_wrapper`` -- confirm against utils.py.
DISTRIBUTED_SELF_ATTN_DATA_SHAPES = {
    "L0": [()],
    "L1": [(32, 1024, 16, 128)],
    "L2": [(32, 512, 12, 64)],
}


class TestDistributedSelfAttn:
    """Sharded self-attention tests for the fused attention kernels (packed ``BS3HD`` layout).

    Each case builds a ``FusedAttnRunner`` with a reference collective count
    (``coll_count_ref``) and runs its backward test.
    """

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        """Build the expected collective byte counts for one training step.

        Args:
            mesh_shape: Size of each mesh axis, parallel to ``mesh_axes``.
            mesh_axes: Names of the mesh axes.
            mesh_resource: Provides ``dp_resource`` / ``tp_resource`` axis names; ``None``
                means that kind of parallelism is not in use.
            with_bias: Whether the attention has a bias (and therefore a dbias to reduce).
            shape: ``(batch, seqlen, heads, hidden)``; only ``seqlen`` and ``heads`` are used.
            dtype: Bias dtype, used for its item size.

        Returns:
            The result of ``generate_collectives_count`` containing only all-reduce traffic.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, heads, _ = shape

        is_dp_enabled = mesh_resource.dp_resource is not None
        tp_size = 1
        if mesh_resource.tp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        # Each TP shard holds heads // tp_size heads of the seqlen x seqlen bias.
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        # dbias all-reduce traffic is only counted when data parallelism is enabled.
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        # for loss and dbias
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    def impl_test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
        use_shardy,
    ):
        """Run a sharded self-attention backward test, skipping if no fused kernel exists.

        Args:
            device_count, mesh_shape, mesh_axes, mesh_resource: Device-mesh configuration
                forwarded to ``FusedAttnRunner``.
            data_shape: ``(batch, seqlen, num_head, hidden)``.
            attn_bias_type: Bias kind; anything other than ``NO_BIAS`` adds dbias traffic to
                the reference collective count.
            bias_shape: Bias layout (``None`` when there is no bias).
            attn_mask_type: Attention mask kind.
            dtype: Input dtype.
            use_shardy: Value of the ``jax_use_shardy_partitioner`` flag while the test runs.
                The previous value is restored afterwards.
        """
        # The partitioner flag is process-global JAX state. Restore it afterwards so tests
        # that do not set it themselves don't depend on which test happened to run before.
        prev_use_shardy = jax.config.jax_use_shardy_partitioner
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        try:
            dropout_prob = 0.0
            is_training = True
            batch, seqlen, num_head, hidden = data_shape

            if not is_fused_attn_kernel_available(
                is_training,
                dtype,
                dtype,
                QKVLayout.BS3HD,
                attn_bias_type,
                attn_mask_type,
                dropout_prob,
                num_head,
                num_head,
                seqlen,
                seqlen,
                hidden,
                hidden,
                None,  # no window
            ):
                pytest.skip("No FusedAttn backend found")

            col_ref = self.generate_collectives_count_ref(
                mesh_shape,
                mesh_axes,
                mesh_resource,
                attn_bias_type != AttnBiasType.NO_BIAS,
                data_shape,
                dtype,
            )
            runner = FusedAttnRunner(
                batch,
                seqlen,
                seqlen,
                num_head,
                num_head,
                hidden,
                hidden,
                attn_bias_type,
                attn_mask_type,
                dropout_prob,
                dtype,
                is_training,
                QKVLayout.BS3HD,
                bias_shape,
                None,  # no window
                SeqDescFormat.Seqlens,
                number_of_devices=device_count,
                mesh_shape=mesh_shape,
                mesh_axes=mesh_axes,
                mesh_resource=mesh_resource,
                coll_count_ref=col_ref,
            )
            runner.test_backward()
        finally:
            jax.config.update("jax_use_shardy_partitioner", prev_use_shardy)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_mask_type",
        [
            pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        ],
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
    ):
        """Sharded self-attention with the GSPMD partitioner (``use_shardy=False``)."""
        self.impl_test_self_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            attn_bias_type,
            bias_shape,
            attn_mask_type,
            dtype,
            use_shardy=False,
        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
        ],
    )
    def test_self_attn_shardy(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
    ):
        """Sharded self-attention with the Shardy partitioner on a single fixed shape."""
        data_shape = (32, 512, 12, 64)
        self.impl_test_self_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            attn_bias_type,
            bias_shape,
            AttnMaskType.PADDING_MASK,
            jnp.bfloat16,
            use_shardy=True,
        )


# Cross-attention data shapes; each shape is unpacked as (batch, seqlen, num_head, hidden).
# NOTE(review): keys presumably select test levels via ``pytest_parametrize_wrapper``.
DISTRIBUTED_CROSS_ATTN_DATA_SHAPES = {
    "L0": [()],
    "L1": [[32, 512, 16, 64]],
    "L2": [[32, 128, 12, 64]],
}


class TestDistributedCrossAttn:
    """Sharded cross-attention tests (``BSHD_BS2HD`` layout, no bias) for fused attention."""

    def generate_collectives_count_ref(self):
        """Return the expected collectives: a single 4-byte FP32 loss all-reduce."""
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_CROSS_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        """Run the cross-attention backward test, skipping if no fused kernel is available.

        ``data_shape`` is ``(batch, seqlen, num_head, hidden)``. The same ``seqlen`` is used
        for both the query and the key/value sequences.
        """
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True

        batch, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            is_training,
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            hidden,
            None,  # no window
        ):
            pytest.skip("No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref()
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BSHD_BS2HD,
            bias_shape,
            None,  # no window
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()


# (qkv_layout, attn_mask_type) combinations exercised by the context-parallel tests.
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
    pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
    pytest.param(
        QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
    ),
]

DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
    # (batch, seqlen, num_head, hidden). Sequence lengths are scaled by 2 * CP inside the
    # test so that we don't run with tiny sizes.
    pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
    pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
]


class TestDistributedContextParallelSelfAttn:
    """Context-parallel (all-gather and ring) self-attention tests for fused attention."""

    def impl_test_context_parallel_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        cp_strategy,
        use_shardy,
        use_scan_ring=False,
        window_size=None,
    ):
        """Run one context-parallel self-attention backward test.

        Unsupported combinations (kernel availability, THD restrictions, mesh divisibility)
        are skipped with ``pytest.skip``.

        Args:
            device_count, mesh_axes, mesh_resource: Device-mesh configuration forwarded to
                ``FusedAttnRunner``.
            mesh_shape: Unpacked as ``(dp_size, cp_size, tp_size)``.
            data_shape: ``(batch, seqlen, num_head, hidden)``. ``seqlen`` is multiplied by
                ``2 * cp_size`` before use.
            kv_groups: Query heads per KV head (``num_kv_heads = num_head // kv_groups``).
            attn_mask_type: Attention mask kind.
            dtype: Input dtype.
            qkv_layout: Q/K/V layout under test.
            load_balanced: Forwarded to the runner as ``cp_load_balanced``.
            cp_strategy: ``CPStrategy.ALL_GATHER`` or ``CPStrategy.RING``.
            use_shardy: Value of ``jax_use_shardy_partitioner`` while the test runs.
            use_scan_ring: Sets ``NVTE_FUSED_RING_ATTENTION_USE_SCAN`` to "1" (ring only).
            window_size: Sliding-window size forwarded to the runner, or ``None``.
        """
        if qkv_layout.is_thd():
            if cp_strategy == CPStrategy.ALL_GATHER:
                pytest.skip("THD doesn't support all gather context parallelism.")
            if not load_balanced and cp_strategy == CPStrategy.RING:
                pytest.skip("THD + ring doesn't support unbalanced context parallelism.")

        assert not use_scan_ring or cp_strategy == CPStrategy.RING

        # Both the scan switch and the partitioner flag are process-global. Save their
        # previous values and restore them in ``finally`` so they never leak into later tests,
        # including when one of the ``pytest.skip`` calls below ends the test early.
        scan_env_var = "NVTE_FUSED_RING_ATTENTION_USE_SCAN"
        prev_scan_env = os.environ.get(scan_env_var)
        prev_use_shardy = jax.config.jax_use_shardy_partitioner
        os.environ[scan_env_var] = "1" if use_scan_ring else "0"
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        try:
            attn_bias_type = AttnBiasType.NO_BIAS
            bias_shape = None
            dropout_prob = 0.0
            is_training = True
            dp_size, cp_size, tp_size = mesh_shape

            batch, seqlen, num_head, hidden = data_shape

            # Scale the sequence length by 2*CP so it's never too small as we scale up the test.
            # 2*CP is used since we split into two CP groups for load balancing.
            seqlen = seqlen * cp_size * 2
            data_shape = batch, seqlen, num_head, hidden

            num_kv_heads = num_head // kv_groups

            runner = FusedAttnRunner(
                batch,
                seqlen,
                seqlen,
                num_head,
                num_kv_heads,
                hidden,
                hidden,
                attn_bias_type,
                attn_mask_type,
                dropout_prob,
                dtype,
                is_training,
                qkv_layout,
                bias_shape,
                window_size,
                SeqDescFormat.SegmentIDs,
                number_of_devices=device_count,
                mesh_shape=mesh_shape,
                mesh_axes=mesh_axes,
                mesh_resource=mesh_resource,
                cp_strategy=cp_strategy,
                cp_load_balanced=load_balanced,
            )

            def check_has_backend_for_mask(mask_type):
                """Return whether a fused kernel exists for this config with ``mask_type``."""
                return is_fused_attn_kernel_available(
                    is_training,
                    dtype,
                    dtype,
                    qkv_layout,
                    attn_bias_type,
                    mask_type,
                    dropout_prob,
                    num_head,
                    num_kv_heads,
                    seqlen,
                    seqlen,
                    hidden,
                    hidden,
                    None,
                )  # no SWA for CP

            # For causal masking we depend on having bottom right support also.
            # The API does not check this and instead we rely on lower level checks to raise
            # an exception if the step backend is not supported. This was a deliberate API
            # decision to keep the CP size or flag out of the function.
            has_backend = check_has_backend_for_mask(attn_mask_type)
            if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
                has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

            if not has_backend:
                pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

            if dp_size > 1 and batch % dp_size != 0:
                pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

            # make sure the mesh evenly divides the cp and tp axes
            if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
                pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

            runner.test_backward()
        finally:
            if prev_scan_env is None:
                os.environ.pop(scan_env_var, None)
            else:
                os.environ[scan_env_var] = prev_scan_env
            jax.config.update("jax_use_shardy_partitioner", prev_use_shardy)

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    def test_context_parallel_allgather_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        qkv_layout,
    ):
        """All-gather CP attention under the Shardy partitioner (load balanced, 8 KV groups)."""
        kv_groups = 8
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced=True,
            cp_strategy=CPStrategy.ALL_GATHER,
            use_shardy=True,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    def test_context_parallel_allgather_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        """All-gather CP attention under the GSPMD partitioner."""
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
            use_shardy=False,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    @pytest.mark.parametrize(
        "use_scan",
        [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
    )
    @pytest.mark.parametrize(
        "window_size",
        [
            pytest.param((-1, -1), id="window_size(-1, -1)"),
            pytest.param((20, 0), id="window_size(20, 0)"),
        ],
    )
    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        use_scan,
        window_size,
    ):
        """Ring CP attention under GSPMD, with and without scan and sliding windows."""
        # (-1, -1) means "no sliding window"; any other window is only tested with THD.
        if window_size != (-1, -1) and not qkv_layout.is_thd():
            pytest.skip("Sliding window attention is only supported for THD layout")
        if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan:
            pytest.skip(
                "When context parallelism and sliding window attention are used, "
                "scanloop is not supported"
            )
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
            use_shardy=False,
            use_scan_ring=use_scan,
            window_size=window_size,
        )

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_context_parallel_configs_for_attn(),
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    def test_context_parallel_ring_attn_shardy(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_mask_type,
        dtype,
        qkv_layout,
    ):
        """Ring CP attention (with scan) under the Shardy partitioner."""
        kv_groups = 8
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced=True,
            cp_strategy=CPStrategy.RING,
            # This is the Shardy variant of the ring test, so Shardy must actually be enabled
            # (mirrors test_context_parallel_allgather_attn_shardy).
            use_shardy=True,
            use_scan_ring=True,
        )


# Tensor shapes for the reorder/inverse round-trip test. They are generated with the sequence
# on axis 1 (BSHD); the SBHD case swaps axes 0 and 1 inside the test.
# NOTE(review): keys presumably select test levels via ``pytest_parametrize_wrapper``.
REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
    "L0": [[]],
    "L1": [[3, 32, 8, 64]],
    "L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}


class TestReorderCausalLoadBalancing:
    """Round-trip tests for ``reorder_causal_load_balancing`` and its inverse."""

    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    @pytest.mark.parametrize(
        "reorder_strategy",
        [
            pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
            pytest.param(ReorderStrategy.Striped, id="Striped"),
        ],
    )
    def test(self, cp_size, shape, qkv_format, reorder_strategy):
        """Check that ``inverse(reorder(x))`` reproduces ``x`` exactly."""
        tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        # Generated shapes are treated as BSHD (sequence on axis 1); SBHD moves it to axis 0.
        seq_dim = 1
        if qkv_format == QKVFormat.SBHD:
            tensor = tensor.swapaxes(0, 1)
            seq_dim = 0

        ref = tensor.copy()

        # strategy, cp_size and seq_dim are plain Python values, so they are jit-static.
        reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
        inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])

        reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
        inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)

        assert jnp.array_equal(inversed, ref)