# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Iterable, List, Union

import pytest

import torch
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    autocast,
    quantized_model_init,
    make_graphed_callables,
    is_fp8_available,
    is_fp8_block_scaling_available,
    is_mxfp8_available,
    is_bf16_available,
)
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe

from utils import ModelConfig, reset_rng_states

# Check if FP8 is supported.
fp8_available = is_fp8_available()
fp8_block_scaling_available = is_fp8_block_scaling_available()
mxfp8_available = is_mxfp8_available()

# Reset RNG states.
reset_rng_states()

# Model configurations available to the tests, keyed by name.
model_configs = {
    "small": ModelConfig(2, 32, 2, 32),
}


def nvfp4_vanilla():
    """Build an NVFP4 recipe whose input, weight and gradient quantizers use default QParams.

    Each quantizer gets its own freshly constructed ``recipe.QParams()`` instance.
    """
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    for field_name in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        setattr(nvfp4_recipe, field_name, recipe.QParams())
    return nvfp4_recipe


def nvfp4_rht_and_2d_quantization():
    """Build an NVFP4 recipe that mixes random Hadamard transforms and 2D quantization.

    Forward inputs and backward gradients use the random Hadamard transform with 1D
    quantization. Weights use 2D quantization without the transform.
    """
    rht_1d = dict(random_hadamard_transform=True, fp4_2d_quantization=False)
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(**rht_1d)
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
        random_hadamard_transform=False,
        fp4_2d_quantization=True,
    )
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(**rht_1d)
    return nvfp4_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
    """Return True if ``recipe`` is NVFP4 and any quantizer enables a random Hadamard transform.

    The forward-input, forward-weight and backward-gradient quantizer settings are
    inspected in that order. When RHT is in use, only bf16 inputs can be supported.
    Non-NVFP4 recipes always return False.
    """
    if not recipe.nvfp4():
        return False
    quantizer_fields = ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad")
    # getattr inside the generator keeps attribute access lazy, like the chained `or`.
    return any(getattr(recipe, field).random_hadamard_transform for field in quantizer_fields)


def get_nvfp4_inp_supported_dtypes(
    recipe: recipe.Recipe, dtype: torch.dtype
) -> List[torch.dtype]:
    """Return the input dtypes that can be used with ``recipe``.

    bfloat16 is listed for NVFP4 recipes. float32 is added whenever no random
    Hadamard transform (RHT) is enabled, as reported by ``check_rht_usage``.
    As a consequence, a non-NVFP4 recipe yields ``[torch.float32]``.

    Args:
        recipe: Quantization recipe to inspect.
        dtype: Unused; kept so existing call sites keep working.

    Returns:
        List of supported input dtypes.
    """
    supported_input_dtypes = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
    # If not using RHT, we can add fp32 as well.
    if not check_rht_usage(recipe):
        supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes


# Quantization recipes to test, each added only when the hardware supports it.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    # NOTE(review): the NVFP4 recipe is gated on MXFP8 availability here; presumably
    # both need the same GPU generation -- confirm.
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_available():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Reset the global FP8 state once each test in this module has finished.

    The fixture is ``autouse``, so it applies without being requested. Everything
    after the ``yield`` is teardown.
    """
    yield None
    # Teardown: clear global quantization state so it cannot leak into the next test.
    FP8GlobalStateManager.reset()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly.

    Args:
        l1: Reference tensors.
        l2: Tensors compared position-by-position against ``l1``.
        names: Optional label for each position, used in the failure message.
            When omitted, positions are reported by index.

    Raises:
        AssertionError: If the lists differ in length, or if any pair of tensors
            is not exactly equal. The error names every mismatching position and
            includes ``torch.testing.assert_close`` details for the first one.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    failed_labels = []
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            failed_labels.append(f"tensor at idx={i}" if names is None else names[i])
            failed_tensors.append((t1, t2))

    if not failed_tensors:
        return

    failure_message = "Output mismatches in:" + "".join(
        f"\n    {label}" for label in failed_labels
    )
    t1, t2 = failed_tensors[0]
    # Exact comparison (rtol=atol=0) gives a detailed diff for the first mismatch.
    # The list of all mismatching outputs is prepended to that diff.
    torch.testing.assert_close(
        t1, t2, rtol=0, atol=0, msg=lambda details: f"{failure_message}\n{details}"
    )
    # torch.equal reported a mismatch that assert_close tolerated; still fail.
    raise AssertionError(failure_message)


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate a synthetic activation tensor on the current CUDA device.

    Args:
        model_config: Supplies the tensor dimensions.
        dtype: Tensor dtype.
        warmup: If True, return all ones instead of standard-normal samples.
            This makes the data deterministic and draws nothing from the RNG.
        requires_grad: Whether the returned tensor requires gradients.

    Returns:
        Tensor of shape ``(model_config.max_seqlen_q, model_config.batch_size,
        model_config.hidden_size)``.
    """
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )


def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return params, grads and outputs for comparison.

    Args:
        model: Module whose parameters, and their ``.grad`` when set, are collected.
        output: A single output tensor or an iterable of output tensors.

    Returns:
        A flat list built in ``model.parameters()`` order. Each parameter is followed
        by its gradient when that is not None. The output tensor(s) come last.
    """
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


163
164
165
166
167
168
169
170
171
172
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


173
174
# Supported modules
_test_cuda_graphs_modules: List[str] = [
175
176
177
    # Put linear first to test the case where the cuda context might not be set in
    # creating TMA descriptor for MXFP8 quantization.
    "linear",
178
179
180
181
182
183
184
185
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "mha",
    "linear_op",
]


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Builds a stack of ``num_layers`` modules of type ``module`` and optionally graphs it.
    It then runs 3 optimizer steps, each accumulating gradients over 2 microbatches.

    Args:
        graph_mode: One of three modes:
            ``"full"`` graphs the whole stack as one callable.
            ``"individual"`` graphs each module separately.
            ``"none"`` runs eagerly as the reference.
        module: Module type name; see ``_test_cuda_graphs_modules``.
        model_config: Sizes for the modules and the synthetic data.
        num_layers: Number of modules in the stack.
        dtype: Parameter and activation dtype.
        fp8: Enables quantization in ``autocast`` and in graph capture.
        fp8_params: Creates modules under ``quantized_model_init``.
        fp8_weight_caching: Caches quantized weights across microbatches.
            Forced off for ``"linear_op"``.
        fp8_recipe: Quantization recipe. The caller passes None when ``fp8`` is False.

    Returns:
        Parameters, gradients and final output as collected by ``get_outputs``.

    Raises:
        ValueError: If ``module`` or ``graph_mode`` is not recognized.
    """
    if graph_mode not in ("full", "individual", "none"):
        raise ValueError(f"Unknown graph mode ({graph_mode})")

    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers. The loop variable is named `layer` so the
        # `module` argument is not shadowed.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                (generate_data(model_config, dtype, warmup=True),),
                num_warmup_iters=10,
                enabled=fp8,
                cache_quantized_params=fp8_weight_caching,
                recipe=fp8_recipe,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    num_warmup_iters=10,
                    enabled=fp8,
                    cache_quantized_params=fp8_weight_caching,
                    recipe=fp8_recipe,
                )
                for layer in modules
            ]
            model = _Sequential(*modules)
        else:
            # graph_mode == "none": eager reference run.
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        # set_to_none=False zeroes the pre-allocated .grad tensors in place instead of
        # freeing them. Presumably the graphed callables depend on those buffers; confirm.
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with autocast(enabled=fp8, recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: Union[recipe.Recipe, None],
    fp8_weight_caching: bool = False,
) -> None:
    """Check that graphed modules reproduce eager results exactly.

    The same stack is trained three ways: graphed as a whole, graphed per module,
    and eagerly. All three must yield identical parameters, gradients and outputs.
    ``fp8_recipe=None`` runs without quantization. Unsupported combinations are skipped.
    """
    fp8 = fp8_recipe is not None
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
        pytest.skip(
            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
        )
    if fp8 and fp8_recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe"
                f" {fp8_recipe.__class__.__name__}"
            )
        if fp8_params:
            pytest.skip("NVFP4 params not supported")

    # Run model with different CUDA graph settings.
    kwargs = dict(
        module=module,
        model_config=model_configs[model_config],
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    # Put graphed callables first to test the case where the cuda context might not be set in
    # creating TMA descriptor for MXFP8 quantization.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)


384
385
386
387
388
389
390
391
392
393
394
395
396
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run ``test_make_graphed_callables`` with FP8 weight caching enabled.

    Only quantized recipes are parametrized here; ``None`` is not included.
    """
    test_make_graphed_callables(
        module=module,
        dtype=dtype,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention.

    Args:
        model_config: Supplies the tensor dimensions.
        dtype: Tensor dtype.
        warmup: If True, the tensors are all ones instead of standard-normal samples.

    Returns:
        Three independent tensors on the current CUDA device with
        ``requires_grad=True``. Each has shape ``(max_seqlen_q, batch_size,
        num_heads, kv_channels)``. They are passed positionally to
        ``DotProductAttention``; presumably they are query, key and value (confirm).
    """
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.max_seqlen_q,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Runs three forward/backward passes of ``DotProductAttention``. The module is
    optionally wrapped with ``make_graphed_callables`` with quantization disabled.

    Args:
        with_graph: Whether to graph the module.
        model_config: Sizes for the module and the synthetic data.
        dtype: Dtype of the synthetic inputs.

    Returns:
        Parameters, gradients and last output as collected by ``get_outputs``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        # NOTE(review): this assumes the attention output has the generate_data shape
        # (max_seqlen_q, batch_size, hidden_size); confirm.
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Check that a graphed DotProductAttention matches the eager module exactly."""
    config = model_configs[model_config]
    eager_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=False, model_config=config, dtype=dtype
    )
    graphed_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=True, model_config=config, dtype=dtype
    )
    assert_all_equal(eager_outputs, graphed_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments.

    A ``TransformerLayer`` with ``self_attn_mask_type="arbitrary"`` is trained for
    3 optimizer steps. Each step accumulates gradients over 2 microbatches. Every
    call passes a fresh random boolean ``attention_mask`` as a keyword argument.

    Args:
        with_graph: Whether to graph the layer. Capture uses an all-False sample mask.
        model_config: Sizes for the layer, the data and the mask.
        dtype: Parameter and activation dtype.

    Returns:
        Parameters, gradients and last output as collected by ``get_outputs``.
    """
    reset_rng_states()

    # Shape of the boolean attention mask passed as a keyword argument.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.max_seqlen_q,
        model_config.max_seqlen_kv,
    )

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _step in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _microbatch in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments.

    Graphed and eager runs of the attention-mask training helper must match exactly.
    """
    kwargs = dict(model_config=model_configs[model_config], dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers each process three microbatches per training step, in the
    order given by ``layer_order``. A positive entry ``k`` runs the forward pass of
    layer ``k`` (1-based) on that layer's next microbatch. A negative entry ``-k``
    runs the backward pass of layer ``k`` on its next not-yet-backpropagated
    microbatch. When ``with_graph`` is True, the same list is passed to
    ``make_graphed_callables`` as ``_order``. Driving the schedule from that list
    keeps the eager and graphed runs consistent.

    Args:
        with_graph: Whether to use graphed callables for each (layer, microbatch).
        model_config: Sizes for the layers and the synthetic data.
        dtype: Parameter and activation dtype.

    Returns:
        Values from ``get_outputs``: parameters, their gradients, then the last
        step's outputs sorted by ``(layer_idx, microbatch_idx)``.
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Forward callable for each (layer_idx, microbatch_idx). In eager mode, all
    # microbatches of a layer share the same module.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # The returned callables are indexed layer-major:
        # all microbatches of layer 0, then all of layer 1.
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): graphed
            for i, graphed in enumerate(graphed_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data. The inputs are drawn before the grad outputs within each
        # (layer, microbatch), which keeps RNG consumption order fixed.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = generate_data(model_config, dtype)
                grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Cache for layer outputs.
        outputs = {}

        # Forward and backward steps, following layer_order.
        next_forward = [0] * num_layers
        next_backward = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, next_forward[layer_idx])
                next_forward[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, next_backward[layer_idx])
                next_backward[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.

    Graphed and eager runs of the interleaved schedule must match exactly.
    """
    kwargs = dict(model_config=model_configs[model_config], dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)