# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
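
"""Tests for CUDA graph capture of Transformer Engine modules via make_graphed_callables."""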

from typing import Iterable, List, Union

import pytest
import torch

from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
    make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe

from utils import ModelConfig, reset_rng_states

# Check if FP8 is supported.
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()

# Reset RNG states.
reset_rng_states()

model_configs = {
    "small": ModelConfig(2, 32, 2, 32),
}


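# NVFP4 recipe helpers: the "vanilla" variant uses default quantization parameters,
# while the RHT variant enables random Hadamard transforms for activations and
# gradients and 2D block quantization for weights.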
def nvfp4_vanilla():
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
    return nvfp4_recipe


def nvfp4_rht_and_2d_quantization():
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
        random_hadamard_transform=True, fp4_2d_quantization=False
    )
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
        random_hadamard_transform=False, fp4_2d_quantization=True
    )
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
        random_hadamard_transform=True, fp4_2d_quantization=False
    )
    return nvfp4_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
    # if using RHT, we can only support bf16
    # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
    if recipe.nvfp4():
        if (
            recipe.fp4_quant_fwd_inp.random_hadamard_transform
            or recipe.fp4_quant_fwd_weight.random_hadamard_transform
            or recipe.fp4_quant_bwd_grad.random_hadamard_transform
        ):
            return True
    return False


def get_nvfp4_inp_supported_dtypes(
    recipe: recipe.Recipe, dtype: torch.dtype
) -> List[torch.dtype]:
    supported_input_dtypes = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
    # If not using RHT, fp32 is supported as well.
    if not check_rht_usage(recipe):
        supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes


fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
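    """Reset FP8 global state after each test."""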
    yield
    FP8GlobalStateManager.reset()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    failure_message = "Output mismatches in:"
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            failure_message += "\n    "
            if names is None:
                failure_message += f"tensor at idx={i}"
            else:
                failure_message += names[i]
            failed_tensors.append((t1, t2))
    if failed_tensors:
        print(failure_message)
        t1, t2 = failed_tensors[0]
        torch.testing.assert_close(t1, t2, rtol=0, atol=0)


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic data."""
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )


def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return model parameters, their gradients, and outputs for comparison."""
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


# Supported modules
_test_cuda_graphs_modules: List[str] = [
    # Put linear first to test the case where the CUDA context might not be set
    # when creating the TMA descriptor for MXFP8 quantization.
    "linear",
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "mha",
    "linear_op",
]


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test."""
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers.
        for mod in modules:
            for param in mod.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                (generate_data(model_config, dtype, warmup=True),),
                num_warmup_iters=10,
                fp8_enabled=fp8,
                fp8_weight_caching=fp8_weight_caching,
                fp8_recipe=fp8_recipe,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    module,
                    (generate_data(model_config, dtype, warmup=True),),
                    num_warmup_iters=10,
                    fp8_enabled=fp8,
                    fp8_weight_caching=fp8_weight_caching,
                    fp8_recipe=fp8_recipe,
                )
                for module in modules
            ]
            model = _Sequential(*modules)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
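                # With FP8 weight caching, weights are quantized and cached on the
                # first microbatch of a step and reused for the following one.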
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
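    """Test CUDA graphs with Transformer Engine modules."""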
    fp8 = fp8_recipe is not None
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
        pytest.skip(
            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
        )
    if fp8 and fp8_recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe"
                f" {fp8_recipe.__class__.__name__}"
            )
        if fp8_params:
            pytest.skip("NVFP4 params not supported")

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=model_config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )

    # Put graphed callables first to test the case where the CUDA context might not be set
    # when creating the TMA descriptor for MXFP8 quantization.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)


_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    test_make_graphed_callables(
        module=module,
        dtype=dtype,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention."""
    gen_func = torch.ones if warmup else torch.randn
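    # Query, key, and value tensors share the same shape.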
    return [
        gen_func(
            model_config.max_seqlen_q,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test."""
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments."""
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(
            (
                model_config.batch_size,
                1,
                model_config.max_seqlen_q,
                model_config.max_seqlen_kv,
            ),
            dtype=torch.bool,
            device="cuda",
        )
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(
                2,
                (
                    model_config.batch_size,
                    1,
                    model_config.max_seqlen_q,
                    model_config.max_seqlen_kv,
                ),
                dtype=torch.bool,
                device="cuda",
            )
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism."""
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
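    # Interleaved 1F1B schedule for 2 layers and 3 microbatches: positive entries
    # are forward passes of the (1-indexed) layer, negative entries the matching
    # backward passes. It mirrors the explicit forward/backward calls below.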
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
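    # layer_forwards maps (layer index, microbatch index) to the forward callable;
    # without graphs every microbatch reuses the same module.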
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        layer_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(layer_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                x = generate_data(model_config, dtype)
                dy = generate_data(model_config, dtype, requires_grad=False)
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = x
                grad_outputs[idxs] = dy

        # Cache for layer outputs.
        outputs = {}

        def forward(layer_idx: int, microbatch_idx: int):
            """Helper function for forward steps"""
            idxs = (layer_idx, microbatch_idx)
            outputs[idxs] = layer_forwards[idxs](inputs[idxs])

        def backward(layer_idx: int, microbatch_idx: int):
            """Helper function for backward steps"""
            outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])

        # Forward and backward steps.
        forward(0, 0)
        forward(1, 0)
        forward(0, 1)
        forward(1, 1)
        backward(1, 0)
        backward(0, 0)
        forward(0, 2)
        forward(1, 2)
        backward(1, 1)
        backward(0, 1)
        backward(1, 2)
        backward(0, 2)

        # Optimizer step.
        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)