# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
import itertools
from typing import Iterable, List, Tuple, Union

import pytest

import torch
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
    make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe

from torch.utils.cpp_extension import IS_HIP_EXTENSION

if IS_HIP_EXTENSION:
    # Only needed by the ROCm-specific rocBLAS skip logic further down.
    import os
    from functools import cache

# Check if FP8 is supported. Each query yields a (supported, reason) pair; the
# reason string is what the tests hand to pytest.skip when support is missing.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
(
    fp8_block_scaling_available,
    reason_for_no_fp8_block_scaling,
) = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()


# Record initial RNG state: seed the CPU and CUDA generators once at import
# time and snapshot them, so reset_rng_states() can rewind to this point.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

@dataclass
class ModelConfig:
    """Data tensor dimensions within Transformer model"""

    # Leading dimension of the synthetic (seq, batch, hidden) activations.
    sequence_length: int
    # Second dimension of the synthetic activations.
    batch_size: int
    # Feature dimension of the synthetic activations and module widths.
    hidden_size: int
    # Number of attention heads.
    num_heads: int
    # Per-head channel count used for the dot-product-attention inputs.
    kv_channels: int

model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}

# FP8 recipes exercised by the parametrized tests below.
fp8_recipes = [
    recipe.DelayedScaling(),
    recipe.MXFP8BlockScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


def reset_rng_states() -> None:
    """Revert to initial RNG state.

    Restores the CPU and CUDA generator states captured at module import, so
    every helper run below starts from the same random streams.
    """
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Teardown-only fixture: clear FP8GlobalStateManager once each test finishes."""
    try:
        yield
    finally:
        FP8GlobalStateManager.reset()

if IS_HIP_EXTENSION:

    @cache
    def use_hipblaslt() -> bool:
        """Return True if NVTE_USE_HIPBLASLT is set or NVTE_USE_ROCBLAS is unset.

        The result is cached, so the environment is read only on the first call.
        """
        return (
            os.getenv("NVTE_USE_HIPBLASLT") is not None
            or os.getenv("NVTE_USE_ROCBLAS") is None
        )

    @pytest.fixture(autouse=True)
    def skip_rocblas():
        """Skip every test in this module when the rocBLAS path is selected."""
        if not use_hipblaslt():
            pytest.skip("CUDA graph capture not supported with rocBLAS path")

def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly.

    Every mismatching position is printed (labelled by ``names[i]`` when
    ``names`` is given, otherwise by index). The first mismatching pair is then
    re-checked with ``torch.testing.assert_close(rtol=0, atol=0)`` so that the
    raised error describes the element-wise differences.

    Raises:
        AssertionError: if the lists differ in length or any pair of tensors
            is not exactly equal.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    failed_labels = []
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            failed_labels.append(f"tensor at idx={i}" if names is None else names[i])
            failed_tensors.append((t1, t2))

    if failed_tensors:
        print("Output mismatches in:" + "".join(f"\n    {label}" for label in failed_labels))
        t1, t2 = failed_tensors[0]
        torch.testing.assert_close(t1, t2, rtol=0, atol=0)


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic data.

    Returns a CUDA tensor of shape
    ``(sequence_length, batch_size, hidden_size)``. Warm-up data is all ones,
    so it draws nothing from the RNG. Otherwise values are standard normal.
    """
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.sequence_length,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )


def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return grads and params for comparison.

    For each parameter of ``model`` (in ``model.parameters()`` order), the
    parameter itself is collected, followed by its ``.grad`` when one exists.
    ``output`` is appended last. A single tensor becomes one entry; any other
    iterable is appended element by element.
    """
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


148
149
150
151
152
153
154
155
156
157
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


158
159
160
161
162
163
164
165
166
167
168
# Supported modules
_test_cuda_graphs_modules: List[str] = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
    "linear_op",
]


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Builds ``num_layers`` modules of type ``module`` and optionally wraps them
    with ``make_graphed_callables``. It then trains them for three SGD steps,
    each accumulating gradients over two micro-batches.

    Args:
        graph_mode: ``"full"`` graphs the whole stack as one callable.
            ``"individual"`` graphs each module separately. Any other value
            (the tests pass ``"none"``) runs the modules eagerly.
        module: Module type; one of the names in ``_test_cuda_graphs_modules``.
        model_config: Dimensions for the modules and the synthetic data.
        num_layers: Number of stacked modules.
        dtype: Parameter and data dtype.
        fp8: Enables FP8 for graph capture and ``fp8_autocast``.
        fp8_params: Creates parameters under ``fp8_model_init``.
        fp8_weight_caching: Enables FP8 weight caching for graph capture and
            passes ``is_first_microbatch`` during training. Forced off for
            ``"linear_op"``.
        fp8_recipe: FP8 recipe for model init, graph capture and autocast.

    Returns:
        Parameters, gradients and the final output, as gathered by ``get_outputs``.

    Raises:
        ValueError: if ``module`` is not a known module type.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Arguments shared by every make_graphed_callables call below.
    graph_kwargs = dict(
        num_warmup_iters=10,
        fp8_enabled=fp8,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )

    # Create modules.
    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers. The loop variable is named ``layer`` so
        # that it does not shadow the ``module`` (type name) parameter.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = make_graphed_callables(
                torch.nn.Sequential(*modules),
                (generate_data(model_config, dtype, warmup=True),),
                **graph_kwargs,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            graphed_modules = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    **graph_kwargs,
                )
                for layer in modules
            ]
            model = _Sequential(*graphed_modules)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that eager, fully-graphed and per-module-graphed runs match exactly."""

    # Skip invalid configurations.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_recipe.float8_block_scaling() and module == "linear_op":
        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")

    # Run model with different CUDA graph settings.
    kwargs = dict(
        module=module,
        model_config=model_configs[model_config],
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)


363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run test_make_graphed_callables in float32 with FP8 and FP8 weight caching on."""
    test_make_graphed_callables(
        module=module,
        dtype=torch.float32,
        fp8=True,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention.

    Returns three CUDA tensors of shape
    ``(sequence_length, batch_size, num_heads, kv_channels)`` with
    ``requires_grad=True``. Callers unpack them positionally into
    ``DotProductAttention`` (presumably as query, key and value). With
    ``warmup=True`` the tensors are all ones; otherwise they are standard normal.
    """
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.sequence_length,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Runs three forward/backward passes through a ``DotProductAttention``
    module, graphed with ``make_graphed_callables`` when ``with_graph`` is
    True. Returns what ``get_outputs`` collects for it.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    # grad_output below comes from generate_data and so has hidden_size
    # channels. This assumes the attention output has num_heads * kv_channels
    # channels per token (TODO confirm against DotProductAttention), so the two
    # must agree for backward() to accept the gradient.
    assert model_config.num_heads * model_config.kv_channels == model_config.hidden_size
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention."""
    config = model_configs[model_config]
    eager_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=False, model_config=config, dtype=dtype
    )
    graphed_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=True, model_config=config, dtype=dtype
    )
    assert_all_equal(eager_outputs, graphed_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments.

    Trains a ``TransformerLayer`` (``self_attn_mask_type="arbitrary"``) whose
    boolean attention mask is passed through the ``attention_mask`` keyword.
    When ``with_graph`` is True, the layer is graphed using an all-False sample
    mask supplied via ``sample_kwargs``.
    """
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Attention masks are boolean (batch, 1, seq, seq) tensors.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.sequence_length,
        model_config.sequence_length,
    )

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments."""
    kwargs = dict(model_config=model_configs[model_config], dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers each process three micro-batches. Forward and
    backward steps are issued in the interleaved order given by
    ``layer_order``, which is also passed as ``_order`` to
    ``make_graphed_callables`` when ``with_graph`` is True. Capture and replay
    therefore share a single schedule.
    """
    reset_rng_states()

    # Pipeline parallel configuration. Entries of layer_order are 1-based layer
    # ids: +k runs layer k-1's forward on its next micro-batch, and -k runs layer
    # k-1's backward on its earliest micro-batch not yet back-propagated.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed. layer_forwards maps
    # (layer_idx, microbatch_idx) to the callable for that step.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        layer_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(layer_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data (input first, then its output gradient, per step).
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = generate_data(model_config, dtype)
                grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Forward and backward steps, replayed from layer_order so the eager
        # schedule cannot drift from the one used for graph capture.
        outputs = {}
        next_forward = [0] * num_layers
        next_backward = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, next_forward[layer_idx])
                next_forward[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, next_backward[layer_idx])
                next_backward[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    kwargs = dict(model_config=model_configs[model_config], dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)