test_cuda_graphs.py 20 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from typing import Iterable, List, Union
6
7
8
9
import pytest

import torch
from transformer_engine.pytorch import (
10
11
12
13
14
15
16
17
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
18
    make_graphed_callables,
19
20
21
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
22
import transformer_engine.pytorch.ops as te_ops
23
from transformer_engine.common import recipe
24
from utils import ModelConfig, reset_rng_states
yuguo's avatar
yuguo committed
25
26
27
28
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
    import os
    from functools import cache
29

30
# Query which FP8 flavours the current device supports; each query also
# returns a human-readable reason that the tests pass to pytest.skip.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Reset RNG states at import time.
reset_rng_states()

# Named model configurations available to the tests.
model_configs = {
    "small": ModelConfig(32, 2, 2, 32),
}

# FP8 recipes to test, limited to those the device supports. The order of
# this list determines the order of the parametrized test cases.
fp8_recipes = [
    make_recipe()
    for is_supported, make_recipe in (
        (mxfp8_available, recipe.MXFP8BlockScaling),
        (fp8_block_scaling_available, recipe.Float8BlockScaling),
        (fp8_available, recipe.Float8CurrentScaling),
        (fp8_available, recipe.DelayedScaling),
    )
    if is_supported
]

# Supported data types; bf16 requires sm_80 or higher.
dtypes: List[torch.dtype] = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Reset ``FP8GlobalStateManager`` after every test in this module.

    The fixture is autouse and does nothing before the test body runs.
    """
    yield
    # Teardown: executed once the test has finished.
    FP8GlobalStateManager.reset()

yuguo's avatar
yuguo committed
62
63
64
65
66
67
68
69
70
if IS_HIP_EXTENSION:

    @cache
    def use_hipblaslt() -> bool:
        """Return True when the hipBLASLt GEMM path is selected.

        hipBLASLt is chosen when ``NVTE_USE_HIPBLASLT`` is set, or when
        ``NVTE_USE_ROCBLAS`` is not set. The answer is cached for the process.
        """
        if os.getenv("NVTE_USE_HIPBLASLT") is not None:
            return True
        return os.getenv("NVTE_USE_ROCBLAS") is None

    @pytest.fixture(autouse=True)
    def skip_rocblas():
        """Skip every test in this module when running on the rocBLAS path."""
        if use_hipblaslt():
            return
        pytest.skip("CUDA graph capture not supported with rocBLAS path")
71
72

def assert_all_equal(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    names: Union[List[str], None] = None,
) -> None:
    """Assert that two lists of tensors match exactly, element by element.

    Args:
        l1: Reference tensors.
        l2: Tensors compared against ``l1`` at the same positions.
        names: Optional labels for the tensors, used in the failure message.
            When omitted, tensors are labelled by their index.

    Raises:
        AssertionError: If the lists differ in length, or if any pair of
            tensors is not exactly equal. The message lists every mismatching
            entry and then details the first mismatch.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    mismatches = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            label = f"tensor at idx={i}" if names is None else names[i]
            mismatches.append((label, t1, t2))
    if not mismatches:
        return

    failure_message = "Output mismatches in:" + "".join(
        f"\n    {label}" for label, _, _ in mismatches
    )
    _, t1, t2 = mismatches[0]
    # Zero tolerances make assert_close report the exact differences of the
    # first mismatching pair; the list of all mismatching entries is prepended.
    torch.testing.assert_close(
        t1,
        t2,
        rtol=0,
        atol=0,
        msg=lambda details: f"{failure_message}\n\n{details}",
    )
    # torch.equal reported a mismatch that assert_close accepted: never pass silently.
    raise AssertionError(failure_message)
89
90
91


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Create a synthetic activation tensor on the GPU.

    The tensor has shape ``(max_seqlen_q, batch_size, hidden_size)`` taken
    from ``model_config``. Warmup data is all ones; otherwise values are
    drawn with ``torch.randn``.
    """
    shape = (
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
    )
    factory_kwargs = dict(device="cuda", requires_grad=requires_grad, dtype=dtype)
    if warmup:
        return torch.ones(*shape, **factory_kwargs)
    return torch.randn(*shape, **factory_kwargs)
107
108


109
110
111
112
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Collect the tensors that are compared between runs.

    Returns every parameter of ``model``, each immediately followed by its
    gradient when one exists, and then ``output`` (a single tensor or an
    iterable of tensors).
    """
    values = [
        tensor
        for param in model.parameters()
        for tensor in ((param,) if param.grad is None else (param, param.grad))
    ]
    values.extend([output] if isinstance(output, torch.Tensor) else output)
    return values


126
127
128
129
130
131
132
133
134
135
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


136
137
# Supported modules
_test_cuda_graphs_modules: List[str] = [
138
139
140
    # Put linear first to test the case where the cuda context might not be set in
    # creating TMA descriptor for MXFP8 quantization.
    "linear",
141
142
143
144
145
146
147
148
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "mha",
    "linear_op",
]


149
150
def _make_test_modules(
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
) -> List[torch.nn.Module]:
    """Construct ``num_layers`` fresh instances of the module type named ``module``.

    The caller wraps this in ``fp8_model_init(enabled=fp8_params, ...)`` so
    that parameter creation honours the requested FP8-parameter setting.

    Raises:
        ValueError: If ``module`` is not a supported module type.
    """
    hidden_size = model_config.hidden_size
    if module == "transformer":
        return [
            TransformerLayer(
                hidden_size,
                hidden_size,
                model_config.num_heads,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    if module == "layernorm_mlp":
        return [
            LayerNormMLP(hidden_size, hidden_size, params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "layernorm_linear":
        return [
            LayerNormLinear(hidden_size, hidden_size, params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "mha":
        return [
            MultiheadAttention(
                hidden_size,
                model_config.num_heads,
                attention_dropout=0.0,
                params_dtype=dtype,
                fuse_qkv_params=True,
            )
            for _ in range(num_layers)
        ]
    if module == "linear":
        return [
            Linear(hidden_size, hidden_size, device="cuda", params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "linear_op":
        return [
            te_ops.Sequential(te_ops.Linear(hidden_size, hidden_size, dtype=dtype))
            for _ in range(num_layers)
        ]
    raise ValueError(f"Unknown module type ({module})")


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Train a stack of modules under one CUDA-graph mode and return results.

    Args:
        graph_mode: ``"full"`` graphs the whole stack with a single
            ``make_graphed_callables`` call, ``"individual"`` graphs each
            module separately, and ``"none"`` runs eagerly.
        module: Module type to stack (one of ``_test_cuda_graphs_modules``).
        model_config: Sizes used for the modules and the synthetic data.
        num_layers: Number of stacked modules.
        dtype: Parameter and data dtype.
        fp8: Whether to run under ``fp8_autocast`` and graph with FP8 enabled.
        fp8_params: Whether to create parameters inside ``fp8_model_init``.
        fp8_weight_caching: Whether to pass ``is_first_microbatch``; forced
            off for ``"linear_op"``.
        fp8_recipe: FP8 recipe, or ``None`` when FP8 is disabled.

    Returns:
        Parameters, their gradients and the final output (see ``get_outputs``).

    Raises:
        ValueError: On an unknown ``graph_mode`` or ``module``.
    """
    if graph_mode not in ("full", "individual", "none"):
        raise ValueError(f"Unknown graph mode ({graph_mode})")

    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    def graph(callable_):
        """Wrap ``callable_`` with make_graphed_callables using this run's FP8 settings."""
        return make_graphed_callables(
            callable_,
            (generate_data(model_config, dtype, warmup=True),),
            num_warmup_iters=10,
            fp8_enabled=fp8,
            fp8_weight_caching=fp8_weight_caching,
            fp8_recipe=fp8_recipe,
        )

    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        modules = _make_test_modules(module, model_config, num_layers, dtype)

        # Give every parameter a .grad buffer up front. The contents are
        # uninitialized; zero_grad(set_to_none=False) clears them before use.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Build the model, graphed according to graph_mode.
        if graph_mode == "full":
            model = graph(torch.nn.Sequential(*modules))
        elif graph_mode == "individual":
            model = _Sequential(*[graph(layer) for layer in modules])
        else:
            model = _Sequential(*modules)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps: 3 optimizer steps, each accumulating 2 microbatches.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    # Flag the first microbatch of each accumulation window
                    # (assumed to be when TE refreshes cached FP8 weights).
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


291
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that CUDA-graphed training matches eager training exactly.

    The same stack of modules is trained three ways: graphed as a whole,
    graphed module by module, and eagerly. Parameters, gradients and outputs
    must be bitwise identical. ``fp8_recipe=None`` disables FP8. This test is
    also called directly by
    ``test_make_graphed_callables_with_fp8_weight_caching``.

    Skips when the requested FP8 configuration is invalid or unsupported.
    """
    fp8 = fp8_recipe is not None
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Run the model with different CUDA graph settings.
    kwargs = dict(
        module=module,
        model_config=model_configs[model_config],
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    # Run the graphed modes first to cover the case where the CUDA context
    # might not be set when the TMA descriptor for MXFP8 quantization is created.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
341
342


343
344
345
346
347
348
349
350
351
352
353
354
355
356
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run test_make_graphed_callables in FP32 with FP8 weight caching enabled."""
    fixed_options = dict(dtype=torch.float32, fp8_weight_caching=True)
    test_make_graphed_callables(
        module=module,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        **fixed_options,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Create three synthetic GPU tensors that are fed positionally to DotProductAttention.

    Each tensor has shape ``(max_seqlen_q, batch_size, num_heads, kv_channels)``
    and requires grad. Warmup data is all ones; otherwise values are drawn
    with ``torch.randn``. (Presumably query, key and value, in that order.)
    """
    make_tensor = torch.ones if warmup else torch.randn
    shape = (
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.num_heads,
        model_config.kv_channels,
    )
    tensors = []
    for _ in range(3):
        tensors.append(
            make_tensor(*shape, device="cuda", requires_grad=True, dtype=dtype)
        )
    return tensors


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Run three forward/backward passes through DotProductAttention.

    When ``with_graph`` is true, the module is first wrapped with
    ``make_graphed_callables`` (FP8 disabled). Returns ``get_outputs`` for the
    final pass.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    if with_graph:
        sample_args = generate_data_for_dot_product_attention(model_config, dtype, warmup=True)
        model = make_graphed_callables(
            model,
            sample_args,
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    for _ in range(3):
        attention_inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*attention_inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Graphed DotProductAttention must reproduce eager results exactly."""
    config = model_configs[model_config]
    eager_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=False, model_config=config, dtype=dtype
    )
    graph_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=True, model_config=config, dtype=dtype
    )
    assert_all_equal(eager_outputs, graph_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Train a TransformerLayer that takes an ``attention_mask`` keyword argument.

    When ``with_graph`` is true, the layer is graphed with a zero boolean
    mask as the sample keyword argument. Training then feeds random boolean
    masks. Returns ``get_outputs`` for the final microbatch.
    """
    reset_rng_states()

    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Pre-create gradient buffers (cleared by zero_grad before use).
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    mask_shape = (
        model_config.batch_size,
        1,
        model_config.max_seqlen_q,
        model_config.max_seqlen_kv,
    )

    if with_graph:
        sample_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=sample_mask),
            allow_unused_input=True,
        )

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # 3 optimizer steps, each accumulating gradients over 2 microbatches.
    for _step in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _microbatch in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """A graphed TransformerLayer called with keyword arguments must match eager execution."""
    config = model_configs[model_config]
    eager_outputs = _test_cuda_graphs_with_kwargs(
        with_graph=False, model_config=config, dtype=dtype
    )
    graph_outputs = _test_cuda_graphs_with_kwargs(
        with_graph=True, model_config=config, dtype=dtype
    )
    assert_all_equal(eager_outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two Linear layers (model chunks) each process three microbatches in the
    order given by ``layer_order``. A positive entry ``k`` runs the forward
    pass of layer ``k`` (1-based) on that layer's next microbatch. A negative
    entry ``-k`` runs the backward pass for layer ``k``'s next
    not-yet-backpropagated microbatch.

    The same schedule is handed to ``make_graphed_callables`` via ``_order``
    and also drives the eager forward/backward loop, so the two cannot diverge.

    Returns ``get_outputs`` for the layers and the outputs of the last step,
    sorted by ``(layer_idx, microbatch_idx)``.
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Pre-create gradient buffers (cleared by zero_grad before use).
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Callable to use for each (layer_idx, microbatch_idx).
    layer_forwards = {
        (layer_idx, microbatch_idx): model[layer_idx]
        for microbatch_idx in range(num_microbatches)
        for layer_idx in range(num_layers)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # The graphed callables are indexed layer-major, then by microbatch.
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(graphed_forwards)
        }

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data, layer-major, input before grad for each microbatch.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = generate_data(model_config, dtype)
                grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Execute the schedule in layer_order, tracking per-layer progress.
        outputs = {}
        num_forwards = [0] * num_layers
        num_backwards = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, num_forwards[layer_idx])
                num_forwards[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, num_backwards[layer_idx])
                num_backwards[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])
        # Every microbatch of every layer must get exactly one forward and backward.
        assert all(
            count == num_microbatches for count in num_forwards + num_backwards
        ), "layer_order does not cover every microbatch exactly once"

        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Graphed layers under an interleaved pipeline schedule must match eager execution."""
    run_options = dict(model_config=model_configs[model_config], dtype=dtype)
    eager_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **run_options,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **run_options,
    )
    assert_all_equal(eager_outputs, graph_outputs)