test_cuda_graphs.py 19.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from dataclasses import dataclass
6
7
import itertools
from typing import Iterable, List, Tuple, Union
8
9
10
11
import pytest

import torch
from transformer_engine.pytorch import (
12
13
14
15
16
17
18
19
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
20
    make_graphed_callables,
21
22
23
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
24
import transformer_engine.pytorch.ops as te_ops
25
from transformer_engine.common import recipe
yuguo's avatar
yuguo committed
26
27
28
29
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
    import os
    from functools import cache
30

31
# Probe which FP8 flavours this device supports; the reason strings are
# reused verbatim as pytest skip messages.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()


# Seed the CPU and CUDA generators once, then snapshot both states so
# reset_rng_states() can rewind to exactly this starting point.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state, _cuda_rng_state = torch.get_rng_state(), torch.cuda.get_rng_state()

43

44
@dataclass
class ModelConfig:
    """Tensor dimensions of the synthetic Transformer workload used by these tests."""

    sequence_length: int  # dim 0 of activations produced by generate_data
    batch_size: int  # dim 1 of activations produced by generate_data
    hidden_size: int  # last dim of activations; module input/output width
    num_heads: int  # attention heads
    kv_channels: int  # per-head channels of attention inputs


model_configs = {
    "small": ModelConfig(
        sequence_length=2,
        batch_size=32,
        hidden_size=64,
        num_heads=2,
        kv_channels=32,
    ),
}

# FP8 recipes exercised by the parametrized tests.
fp8_recipes = [recipe.DelayedScaling(), recipe.MXFP8BlockScaling()]

# Supported data types. bf16 requires sm_80 or higher, so it is only added
# when the device reports support for it.
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible():
    dtypes.append(torch.bfloat16)


def reset_rng_states() -> None:
    """Restore the CPU and CUDA RNGs to the states captured at module import."""
    for set_state, state in (
        (torch.set_rng_state, _cpu_rng_state),
        (torch.cuda.set_rng_state, _cuda_rng_state),
    ):
        set_state(state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Clear FP8GlobalStateManager after each test so FP8 state cannot leak across tests."""
    try:
        yield
    finally:
        FP8GlobalStateManager.reset()

yuguo's avatar
yuguo committed
79
80
81
82
83
84
85
86
87
if IS_HIP_EXTENSION:

    @cache
    def use_hipblaslt() -> bool:
        """Return True unless rocBLAS is requested without also requesting hipBLASLt.

        Reads the NVTE_USE_HIPBLASLT / NVTE_USE_ROCBLAS environment variables;
        the result is cached for the lifetime of the process.
        """
        if os.getenv("NVTE_USE_HIPBLASLT") is not None:
            return True
        return os.getenv("NVTE_USE_ROCBLAS") is None

    @pytest.fixture(autouse=True)
    def skip_rocblas():
        """Skip every test in this module when the rocBLAS GEMM path is selected."""
        if use_hipblaslt():
            return
        pytest.skip("CUDA graph capture not supported with rocBLAS path")
88
89

def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly.

    Args:
        l1: Reference tensors.
        l2: Tensors compared position by position against ``l1``.
        names: Optional labels indexed by position, used in the failure report.

    Raises:
        AssertionError: If the lists differ in length, or if any pair of tensors
            is not exactly equal. The error lists every mismatching position and
            includes the detailed element-wise report for the first mismatch.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    failed = [(i, t1, t2) for i, (t1, t2) in enumerate(zip(l1, l2)) if not torch.equal(t1, t2)]
    if not failed:
        return

    labels = [f"tensor at idx={i}" if names is None else names[i] for i, _, _ in failed]
    failure_message = "Output mismatches in:" + "".join(f"\n    {label}" for label in labels)
    print(failure_message)

    # torch.equal only yields a bool; assert_close with zero tolerance produces a
    # detailed diff. Prepend the summary so it survives in the raised error.
    _, t1, t2 = failed[0]
    torch.testing.assert_close(
        t1,
        t2,
        rtol=0,
        atol=0,
        msg=lambda detail: f"{failure_message}\n\n{detail}",
    )
106
107
108


def generate_data(
    model_config: "ModelConfig",
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
    device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
    """Generate synthetic data.

    Args:
        model_config: Supplies the shape
            ``(sequence_length, batch_size, hidden_size)``.
        dtype: Element type of the returned tensor.
        warmup: If True, fill with ones. This is used for graph-capture sample
            inputs and consumes no RNG state. Otherwise draw from a standard
            normal distribution.
        requires_grad: Whether the returned tensor tracks gradients.
        device: Device to allocate on. Defaults to ``"cuda"``; exposed so the
            helper can also produce CPU tensors.

    Returns:
        Tensor of shape ``(sequence_length, batch_size, hidden_size)``.
    """
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.sequence_length,
        model_config.batch_size,
        model_config.hidden_size,
        device=device,
        requires_grad=requires_grad,
        dtype=dtype,
    )
124
125


126
127
128
129
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Collect parameters, their gradients, and outputs for comparison.

    Each parameter is followed immediately by its ``.grad`` when one exists;
    the model output (a single tensor or an iterable of tensors) comes last.
    """
    collected: List[torch.Tensor] = []
    for p in model.parameters():
        collected.append(p)
        if p.grad is not None:
            collected.append(p.grad)
    collected.extend([output] if isinstance(output, torch.Tensor) else output)
    return collected


143
144
145
146
147
148
149
150
151
152
class _Sequential(torch.nn.Sequential):
    """``torch.nn.Sequential`` variant that passes keyword arguments to every child."""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        # Thread the activation through each child, re-sending the same kwargs.
        out = input_
        for layer in self:
            out = layer(out, **kwargs)
        return out


153
154
155
156
157
158
159
160
161
162
163
# Module types exercised by test_make_graphed_callables. Each name must be
# handled by the module-construction dispatch in _test_cuda_graphs.
_test_cuda_graphs_modules: List[str] = (
    "transformer layernorm_mlp layernorm_linear linear mha linear_op".split()
)


164
165
def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Builds ``num_layers`` instances of the requested module type, optionally
    wraps them with ``make_graphed_callables``, and trains for three optimizer
    steps of two gradient-accumulation micro-steps each.

    Args:
        graph_mode: ``"none"`` runs eagerly, ``"full"`` graphs the whole stack
            as one callable, ``"individual"`` graphs each module separately.
        module: Module type name, one of ``_test_cuda_graphs_modules``.
        model_config: Tensor dimensions for modules and synthetic data.
        num_layers: Number of stacked modules.
        dtype: Parameter and activation dtype.
        fp8: Enables FP8 in graph capture and in ``fp8_autocast``.
        fp8_params: Creates parameters under ``fp8_model_init``.
        fp8_weight_caching: Passes ``is_first_microbatch`` to the model and
            enables FP8 weight caching when graphing. Forced off for
            ``"linear_op"``.
        fp8_recipe: FP8 recipe used for model init, graphing and autocast.

    Returns:
        Final parameters, their gradients and the last output, as collected
        by ``get_outputs``.

    Raises:
        ValueError: If ``graph_mode`` or ``module`` is not recognized.
    """
    if graph_mode not in ("none", "full", "individual"):
        raise ValueError(f"Unknown graph mode ({graph_mode})")

    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers. The loop variable is ``layer`` so the
        # ``module`` (type name) argument is not shadowed.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version. Sample inputs
        # use warmup=True (all ones), so capturing consumes no RNG state.
        graph_kwargs = dict(
            num_warmup_iters=10,
            fp8_enabled=fp8,
            fp8_weight_caching=fp8_weight_caching,
            fp8_recipe=fp8_recipe,
        )
        if graph_mode == "full":
            # Graph entire model at once.
            model = make_graphed_callables(
                torch.nn.Sequential(*modules),
                (generate_data(model_config, dtype, warmup=True),),
                **graph_kwargs,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            graphed_modules = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    **graph_kwargs,
                )
                for layer in modules
            ]
            model = _Sequential(*graphed_modules)
        else:  # graph_mode == "none"
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        # Zero gradients in place instead of setting them to None, keeping the
        # buffers allocated above.
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    # Flag the first micro-step of each optimizer step; TE uses
                    # this for FP8 weight caching.
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


306
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that CUDA-graphed modules reproduce eager results exactly.

    Runs the same training loop eagerly, with the whole stack graphed, and with
    each module graphed individually. It then compares the parameters,
    gradients and final output bit-for-bit.
    """
    # Skip invalid configurations.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Run model with different CUDA graph settings.
    config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
352
353


354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
# Module types for the FP8 weight caching test. "linear_op" is left out:
# the operation-based API does not support FP8 weight caching.
_test_make_graphed_callables_with_fp8_weight_caching_modules = (
    "transformer layernorm_mlp layernorm_linear linear mha".split()
)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run the main CUDA graph test in FP32 with FP8 compute and FP8 weight caching."""
    fixed_settings = dict(dtype=torch.float32, fp8=True, fp8_weight_caching=True)
    test_make_graphed_callables(
        module=module,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        **fixed_settings,
    )


def generate_data_for_dot_product_attention(
    model_config: "ModelConfig",
    dtype: torch.dtype,
    warmup: bool = False,
    device: Union[str, torch.device] = "cuda",
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention.

    Args:
        model_config: Supplies the shape
            ``(sequence_length, batch_size, num_heads, kv_channels)``.
        dtype: Element type of the returned tensors.
        warmup: If True, fill with ones (used for graph-capture samples);
            otherwise draw from a standard normal distribution.
        device: Device to allocate on. Defaults to ``"cuda"``; exposed so the
            helper can also produce CPU tensors.

    Returns:
        Three independent tensors, each of shape
        ``(sequence_length, batch_size, num_heads, kv_channels)`` with
        ``requires_grad=True``. They are passed positionally to
        DotProductAttention, presumably as query, key and value.
    """
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.sequence_length,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device=device,
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with dot product attention.

    Runs three forward/backward passes through a ``DotProductAttention``
    module, optionally captured with ``make_graphed_callables``.

    Returns:
        The module's parameters and gradients (if any) followed by the last
        output, as collected by ``get_outputs``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # The upstream gradient comes from ``generate_data`` and so has last
    # dimension ``hidden_size``. The attention output's last dimension is
    # presumably num_heads * kv_channels (default layout), so the two must
    # agree for ``backward`` to accept the gradient. This also implies that
    # hidden_size is divisible by num_heads.
    assert model_config.num_heads * model_config.kv_channels == model_config.hidden_size, (
        "num_heads * kv_channels must equal hidden_size so the generated "
        "gradient matches the attention output"
    )

    # Create dot product attention module.
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention."""
    shared = dict(model_config=model_configs[model_config], dtype=dtype)
    eager_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **shared)
    graphed_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **shared)
    assert_all_equal(eager_outputs, graphed_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments.

    Trains a single ``TransformerLayer`` (``self_attn_mask_type="arbitrary"``)
    for three optimizer steps of two micro-steps each. A random boolean
    ``attention_mask`` is passed as a keyword argument. When ``with_graph``
    is set, the layer is first captured with ``make_graphed_callables``, using
    an all-False mask as the sample keyword argument.

    Returns:
        Final parameters, gradients and last output, as collected by
        ``get_outputs``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Shape of the boolean attention masks: (batch, 1, seq, seq).
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.sequence_length,
        model_config.sequence_length,
    )

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed. The sample mask is all zeros,
    # so capturing consumes no RNG state.
    if with_graph:
        attn_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop: 3 optimizer steps x 2 gradient-accumulation micro-steps.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _microstep in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments."""
    shared = dict(model_config=model_configs[model_config], dtype=dtype)
    eager_outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **shared)
    graphed_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **shared)
    assert_all_equal(eager_outputs, graphed_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers act as model chunks and each process three
    microbatches. Forward and backward passes are interleaved according to
    ``layer_order``. When ``with_graph`` is set, every (layer, microbatch)
    pair gets its own graphed callable, built with the same schedule.

    Returns:
        Final parameters and gradients, followed by the outputs of every
        (layer, microbatch) pair from the last training step, sorted by
        (layer, microbatch).
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Pipeline parallel configuration. In ``layer_order``, +k runs the forward
    # pass of layer k (1-indexed) on its next microbatch and -k runs the
    # backward pass of layer k on its next microbatch.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Map each (layer, microbatch) pair to the callable that runs it.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(graphed_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data (input first, then upstream gradient, per pair).
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = generate_data(model_config, dtype)
                grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Cache for layer outputs.
        outputs = {}

        # Forward and backward steps. Replay ``layer_order`` (the schedule
        # passed to make_graphed_callables) so the executed schedule and the
        # captured one cannot drift apart.
        next_forward = [0] * num_layers
        next_backward = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, next_forward[layer_idx])
                next_forward[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, next_backward[layer_idx])
                next_backward[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    last_outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, last_outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    shared = dict(model_config=model_configs[model_config], dtype=dtype)
    eager_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **shared,
    )
    graphed_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **shared,
    )
    assert_all_equal(eager_outputs, graphed_outputs)