test_cuda_graphs.py 19.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from dataclasses import dataclass
6
7
import itertools
from typing import Iterable, List, Tuple, Union
8
9
10
11
import pytest

import torch
from transformer_engine.pytorch import (
12
13
14
15
16
17
18
19
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
20
    make_graphed_callables,
21
22
23
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
24
import transformer_engine.pytorch.ops as te_ops
25
from transformer_engine.common import recipe
yuguo's avatar
yuguo committed
26
27
28
29
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
    import os
    from functools import cache
30

31
# Check if FP8 is supported.
32
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
33
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
34
35


36
# Record initial RNG state.
37
38
39
40
41
42
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

43

44
@dataclass
class ModelConfig:
    """Tensor dimensions used to size the test models and their inputs."""

    # First dimension of the tensors produced by the data generators.
    sequence_length: int
    # Second dimension of the tensors produced by the data generators.
    batch_size: int
    # Last dimension of generated activations; also the width given to layers.
    hidden_size: int
    # Number of attention heads handed to the attention-based modules.
    num_heads: int
    # Per-head channel count used by the dot-product-attention test.
    kv_channels: int
53

54

55
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
56

57
58
59
fp8_recipes = [
    recipe.DelayedScaling(),
    recipe.MXFP8BlockScaling(),
60
    recipe.Float8CurrentScaling(),
61
62
]

63
64
# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
65
66
67
68
69
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


def reset_rng_states() -> None:
    """Restore the CPU and CUDA RNG states recorded at module import.

    Each helper calls this first so that runs being compared draw their
    random tensors from the same starting state.
    """
    state_restorers = (
        (torch.set_rng_state, _cpu_rng_state),
        (torch.cuda.set_rng_state, _cuda_rng_state),
    )
    for restore, snapshot in state_restorers:
        restore(snapshot)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets the global FP8 state once the test is done."""
    yield
    # Teardown: runs after the test body has finished.
    FP8GlobalStateManager.reset()

yuguo's avatar
yuguo committed
80
81
82
83
84
85
86
87
88
if IS_HIP_EXTENSION:

    @cache
    def use_hipblaslt() -> bool:
        """Report whether the hipBLASLt path is selected via the environment.

        True when ``NVTE_USE_HIPBLASLT`` is set, or when ``NVTE_USE_ROCBLAS``
        is not set. The result is cached after the first call.
        """
        hipblaslt_requested = os.getenv("NVTE_USE_HIPBLASLT") is not None
        rocblas_requested = os.getenv("NVTE_USE_ROCBLAS") is not None
        return hipblaslt_requested or not rocblas_requested

    @pytest.fixture(autouse=True)
    def skip_rocblas():
        """Autouse fixture that skips the test unless hipBLASLt is selected."""
        if use_hipblaslt():
            return
        pytest.skip("CUDA graph capture not supported with rocBLAS path")
89
90

def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly.

    Args:
        l1: Reference tensors.
        l2: Tensors compared against ``l1`` position by position.
        names: Optional labels, indexed like ``l1``, used in the mismatch
            report instead of the positional index.

    Raises:
        AssertionError: If the lists differ in length or if any pair of
            tensors is not equal according to ``torch.equal``.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    failure_message = "Output mismatches in:"
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            label = f"tensor at idx={i}" if names is None else names[i]
            failure_message += "\n    " + label
            failed_tensors.append((t1, t2))
    if failed_tensors:
        # Print the complete list of mismatching tensors, then let
        # assert_close produce a detailed diff for the first one.
        print(failure_message)
        t1, t2 = failed_tensors[0]
        torch.testing.assert_close(t1, t2, rtol=0, atol=0)
        # Safety net: a mismatch reported by torch.equal must never pass
        # silently, even if assert_close did not raise for this pair.
        raise AssertionError(failure_message)
107
108
109


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Create a synthetic CUDA tensor sized from ``model_config``.

    The tensor has shape ``(sequence_length, batch_size, hidden_size)``.
    With ``warmup=True`` it is filled with ones; otherwise it is drawn from
    ``torch.randn``.
    """
    shape = (
        model_config.sequence_length,
        model_config.batch_size,
        model_config.hidden_size,
    )
    if warmup:
        factory = torch.ones
    else:
        factory = torch.randn
    return factory(*shape, device="cuda", requires_grad=requires_grad, dtype=dtype)
125
126


127
128
129
130
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Collect parameters, their gradients, and outputs for comparison.

    Each parameter of ``model`` is listed in iteration order, immediately
    followed by its gradient when one is set. ``output`` is appended last:
    as a single entry if it is a tensor, otherwise element by element.
    """
    collected = []
    for parameter in model.parameters():
        collected.append(parameter)
        gradient = parameter.grad
        if gradient is not None:
            collected.append(gradient)
    # Check for a tensor first: a tensor is itself iterable.
    collected += [output] if isinstance(output, torch.Tensor) else list(output)
    return collected


144
145
146
147
148
149
150
151
152
153
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


154
155
156
157
158
159
160
161
162
163
164
# Supported modules
_test_cuda_graphs_modules: List[str] = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
    "linear_op",
]


165
166
def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Train a small stack of layers and return the tensors to compare.

    ``graph_mode`` selects how the stack is executed: ``"full"`` wraps the
    whole stack in a single ``make_graphed_callables`` call, ``"individual"``
    wraps each layer separately, and any other value runs the layers without
    graphing. ``module`` names the layer type to build ``num_layers`` times.

    Returns:
        The result of ``get_outputs`` for the trained model and the output of
        the last forward pass.

    Raises:
        ValueError: If ``module`` is not a known layer type.
    """
    # Start every run from the same RNG and FP8 global state.
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    use_weight_caching = False if module == "linear_op" else fp8_weight_caching

    # One constructor per supported layer type. Nothing is built here; the
    # selected factory is invoked inside the fp8_model_init context below.
    hidden_size = model_config.hidden_size
    layer_factories = {
        "transformer": lambda: TransformerLayer(
            hidden_size,
            hidden_size,
            model_config.num_heads,
            hidden_dropout=0.0,
            attention_dropout=0.0,
            fuse_qkv_params=True,
            params_dtype=dtype,
        ),
        "layernorm_mlp": lambda: LayerNormMLP(
            hidden_size,
            hidden_size,
            params_dtype=dtype,
        ),
        "layernorm_linear": lambda: LayerNormLinear(
            hidden_size,
            hidden_size,
            params_dtype=dtype,
        ),
        "mha": lambda: MultiheadAttention(
            hidden_size,
            model_config.num_heads,
            attention_dropout=0.0,
            params_dtype=dtype,
            fuse_qkv_params=True,
        ),
        "linear": lambda: Linear(
            hidden_size,
            hidden_size,
            device="cuda",
            params_dtype=dtype,
        ),
        "linear_op": lambda: te_ops.Sequential(
            te_ops.Linear(
                hidden_size,
                hidden_size,
                dtype=dtype,
            ),
        ),
    }

    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        # Build the layers.
        make_layer = layer_factories.get(module)
        if make_layer is None:
            raise ValueError(f"Unknown module type ({module})")
        layers = [make_layer() for _ in range(num_layers)]

        # Give every parameter a gradient buffer up front.
        for layer in layers:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Options shared by every make_graphed_callables call.
        graph_options = dict(
            num_warmup_iters=10,
            fp8_enabled=fp8,
            fp8_weight_caching=use_weight_caching,
            fp8_recipe=fp8_recipe,
        )

        if graph_mode == "full":
            # Graph the entire stack at once.
            model = make_graphed_callables(
                torch.nn.Sequential(*layers),
                (generate_data(model_config, dtype, warmup=True),),
                **graph_options,
            )
        elif graph_mode == "individual":
            # Graph each layer on its own, then chain the graphed layers.
            graphed_layers = []
            for layer in layers:
                warmup_input = generate_data(model_config, dtype, warmup=True)
                graphed_layers.append(
                    make_graphed_callables(layer, (warmup_input,), **graph_options)
                )
            model = _Sequential(*graphed_layers)
        else:
            # No graphing.
            model = _Sequential(*layers)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Three optimizer steps, each accumulating gradients over two microbatches.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for microbatch in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            if use_weight_caching:
                forward_kwargs = {"is_first_microbatch": microbatch == 0}
            else:
                forward_kwargs = {}
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                output = model(input_, **forward_kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


307
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that graphed and non-graphed runs produce identical tensors.

    The same configuration is run without graphing, with the whole model
    graphed, and with each layer graphed individually; both graphed results
    must match the non-graphed one exactly.

    ``model_config`` is a key into ``model_configs``.
    """

    # Skip invalid configurations.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        # This reason previously repeated the FP8-parameters message.
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=model_config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
353
354


355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run the graphed-callables comparison with FP8 weight caching enabled.

    Delegates to ``test_make_graphed_callables`` with FP8 turned on, float32
    data, and ``fp8_weight_caching=True``.
    """
    fixed_settings = dict(
        dtype=torch.float32,
        fp8=True,
        fp8_weight_caching=True,
    )
    test_make_graphed_callables(
        module=module,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        **fixed_settings,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Create three synthetic CUDA inputs for dot product attention.

    Each tensor has shape ``(sequence_length, batch_size, num_heads,
    kv_channels)`` and requires grad. With ``warmup=True`` the tensors are
    filled with ones; otherwise they are drawn from ``torch.randn``.
    """
    shape = (
        model_config.sequence_length,
        model_config.batch_size,
        model_config.num_heads,
        model_config.kv_channels,
    )
    factory = torch.ones if warmup else torch.randn
    tensors = []
    for _ in range(3):
        tensors.append(factory(*shape, device="cuda", requires_grad=True, dtype=dtype))
    return tensors


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Run dot product attention forward/backward, optionally graphed.

    Returns the result of ``get_outputs`` for the module and the output of
    the last forward pass.
    """
    # Start from the same RNG and FP8 global state on every run.
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    heads = model_config.num_heads
    assert model_config.hidden_size % heads == 0
    model = DotProductAttention(
        heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        warmup_inputs = generate_data_for_dot_product_attention(
            model_config,
            dtype,
            warmup=True,
        )
        model = make_graphed_callables(
            model,
            warmup_inputs,
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        first, second, third = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(first, second, third)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Graphed and non-graphed dot product attention must match exactly."""
    config = model_configs[model_config]
    reference = _test_cuda_graphs_with_dot_product_attention(
        with_graph=False,
        model_config=config,
        dtype=dtype,
    )
    graphed = _test_cuda_graphs_with_dot_product_attention(
        with_graph=True,
        model_config=config,
        dtype=dtype,
    )
    assert_all_equal(reference, graphed)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Train a transformer layer that receives its mask as a keyword argument.

    When ``with_graph`` is set, the layer is graphed with an all-False sample
    mask passed through ``sample_kwargs``; training then feeds a fresh random
    mask on every step.

    Returns the result of ``get_outputs`` for the model and the output of the
    last forward pass.
    """
    reset_rng_states()

    # Initialize model.
    hidden_size = model_config.hidden_size
    model = TransformerLayer(
        hidden_size,
        hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Shape shared by the sample mask and the per-step masks.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.sequence_length,
        model_config.sequence_length,
    )

    # Make graphed version of model if needed.
    if with_graph:
        sample_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=sample_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Three optimizer steps, each accumulating gradients over two microbatches.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _microbatch in range(2):
            # Keep this generation order: it fixes how the RNG is consumed.
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            step_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=step_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Graphed and non-graphed runs with a keyword-argument mask must match."""
    config = model_configs[model_config]
    reference = _test_cuda_graphs_with_kwargs(
        with_graph=False,
        model_config=config,
        dtype=dtype,
    )
    graphed = _test_cuda_graphs_with_kwargs(
        with_graph=True,
        model_config=config,
        dtype=dtype,
    )
    assert_all_equal(reference, graphed)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two linear layers process three microbatches each, with forward and
    backward steps interleaved in a fixed schedule. When ``with_graph`` is
    set, one graphed callable per (layer, microbatch) pair is used.

    Returns the result of ``get_outputs`` for the model and the per-step
    outputs of the last training iteration, ordered by (layer, microbatch).
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    # Passed to make_graphed_callables as ``_order``.
    # NOTE(review): presumably positive entries are forward steps and negative
    # entries backward steps of the 1-based layer index, matching the schedule
    # below - confirm against make_graphed_callables.
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Map every (layer, microbatch) pair to the callable that runs it.
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed_callables = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # The returned callables are indexed layer-major:
        # position i corresponds to (i // num_microbatches, i % num_microbatches).
        layer_forwards = {
            divmod(position, num_microbatches): fn
            for position, fn in enumerate(graphed_callables)
        }
    else:
        layer_forwards = {
            (layer_idx, microbatch_idx): model[layer_idx]
            for microbatch_idx in range(num_microbatches)
            for layer_idx in range(num_layers)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Interleaved schedule as (is_forward, (layer, microbatch)) steps.
    schedule = (
        (True, (0, 0)),
        (True, (1, 0)),
        (True, (0, 1)),
        (True, (1, 1)),
        (False, (1, 0)),
        (False, (0, 0)),
        (True, (0, 2)),
        (True, (1, 2)),
        (False, (1, 1)),
        (False, (0, 1)),
        (False, (1, 2)),
        (False, (0, 2)),
    )

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data: an input and an output gradient for every pair.
        inputs = {}
        grad_outputs = {}
        for idxs in itertools.product(range(num_layers), range(num_microbatches)):
            inputs[idxs] = generate_data(model_config, dtype)
            grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Cache for layer outputs, rebuilt on every iteration.
        step_outputs = {}

        # Forward and backward steps.
        for is_forward, idxs in schedule:
            if is_forward:
                step_outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                step_outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    # Outputs of the final iteration, ordered by (layer, microbatch).
    ordered_outputs = [step_outputs[idxs] for idxs in sorted(step_outputs)]
    return get_outputs(model, ordered_outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Graphed and non-graphed interleaved-pipeline runs must match exactly."""
    config = model_configs[model_config]
    results = [
        _test_cuda_graphs_with_interleaved_pipeline_parallelism(
            with_graph=graphed,
            model_config=config,
            dtype=dtype,
        )
        for graphed in (False, True)
    ]
    reference, graph_outputs = results
    assert_all_equal(reference, graph_outputs)