# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
import itertools
from typing import Iterable, List, Tuple, Union

import pytest
import torch

from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
    make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe


# Probe which FP8 flavours the current device supports. Each probe yields an
# (available, reason) pair; the reason string doubles as the pytest skip
# message when the feature is missing.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()


# Seed the CPU and CUDA generators once at import time, then snapshot their
# states so every helper can rewind to the same starting point via
# ``reset_rng_states``.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

@dataclass
class ModelConfig:
    """Data tensor dimensions within Transformer model.

    Fields are positional in declaration order, so
    ``ModelConfig(2, 32, 64, 2, 32)`` means sequence_length=2, batch_size=32,
    hidden_size=64, num_heads=2, kv_channels=32.
    """

    sequence_length: int  # first dim of generated activations
    batch_size: int  # second dim of generated activations
    hidden_size: int  # model width; last dim of generated activations
    num_heads: int  # number of attention heads
    kv_channels: int  # per-head channels for the dot-product-attention inputs

# Named model configurations, looked up by the ``model_config`` argument of
# the tests below.
model_configs = {
    "small": ModelConfig(
        sequence_length=2,
        batch_size=32,
        hidden_size=64,
        num_heads=2,
        kv_channels=32,
    ),
}

# FP8 recipes exercised by the tests. Recipes that the current device cannot
# run are skipped at test time (see the availability flags above).
fp8_recipes = [
    recipe.DelayedScaling(),
    recipe.MXFP8BlockScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]

# Supported data types: FP32 and FP16 always; BF16 only on devices that
# support it (bf16 requires sm_80 or higher).
dtypes: List[torch.dtype] = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)


def reset_rng_states() -> None:
    """Revert the CPU and CUDA RNGs to the states captured at import time.

    Every test helper calls this first so that eager and graphed runs draw
    identical random parameters and inputs.
    """
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Clear global FP8 state after every test in this module."""
    # No setup needed: the test body runs while suspended at this ``yield``.
    yield
    # Teardown: drop whatever global FP8 state the test left behind.
    FP8GlobalStateManager.reset()


def assert_all_equal(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    names: Union[List[str], None] = None,
) -> None:
    """Check that two lists of tensors match exactly (bitwise).

    Args:
        l1: Reference tensors.
        l2: Tensors compared element-wise against ``l1``.
        names: Optional labels for the entries, used in the failure report.
            Must have the same length as ``l1`` when given.

    Raises:
        AssertionError: If the lists differ in length, or if any pair of
            tensors is not exactly equal. The labels of all mismatching
            entries are printed. The first mismatch is then reported in
            detail via ``torch.testing.assert_close``.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    if names is not None:
        assert len(names) == len(l1), "Unequal number of names and outputs."

    failure_message = "Output mismatches in:"
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if torch.equal(t1, t2):
            continue
        label = f"tensor at idx={i}" if names is None else names[i]
        failure_message += f"\n    {label}"
        failed_tensors.append((t1, t2))

    if failed_tensors:
        print(failure_message)
        t1, t2 = failed_tensors[0]
        # Zero tolerances give a detailed element-wise report of the first mismatch.
        torch.testing.assert_close(t1, t2, rtol=0, atol=0)
        # If assert_close accepted a pair that torch.equal rejected, still fail:
        # a detected mismatch must never pass silently.
        raise AssertionError(failure_message)


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic activations on the GPU.

    Args:
        model_config: Supplies the tensor shape
            ``(sequence_length, batch_size, hidden_size)``.
        dtype: Data type of the generated tensor.
        warmup: If True, return an all-ones tensor. The tests use these as
            CUDA graph capture samples. Otherwise, sample from a standard
            normal distribution.
        requires_grad: Whether the returned tensor tracks gradients.

    Returns:
        A CUDA tensor of the shape described above.
    """
    shape = (
        model_config.sequence_length,
        model_config.batch_size,
        model_config.hidden_size,
    )
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(*shape, device="cuda", requires_grad=requires_grad, dtype=dtype)


def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return params, grads and outputs for comparison.

    The result is a flat list. For each parameter (in ``model.parameters()``
    order) it holds the parameter, followed by its gradient unless the
    gradient is ``None``. The output tensor(s) come last.
    """
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


136
137
138
139
140
141
142
143
144
145
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


146
147
148
149
150
151
152
153
154
155
156
# Supported modules
_test_cuda_graphs_modules: List[str] = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
    "linear_op",
]


def _make_test_modules(
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
) -> List[torch.nn.Module]:
    """Construct ``num_layers`` fresh instances of the module type named ``module``.

    Raises:
        ValueError: If ``module`` is not a known module type.
    """
    hidden_size = model_config.hidden_size
    factories = {
        "transformer": lambda: TransformerLayer(
            hidden_size,
            hidden_size,
            model_config.num_heads,
            hidden_dropout=0.0,
            attention_dropout=0.0,
            fuse_qkv_params=True,
            params_dtype=dtype,
        ),
        "layernorm_mlp": lambda: LayerNormMLP(
            hidden_size,
            hidden_size,
            params_dtype=dtype,
        ),
        "layernorm_linear": lambda: LayerNormLinear(
            hidden_size,
            hidden_size,
            params_dtype=dtype,
        ),
        "mha": lambda: MultiheadAttention(
            hidden_size,
            model_config.num_heads,
            attention_dropout=0.0,
            params_dtype=dtype,
            fuse_qkv_params=True,
        ),
        "linear": lambda: Linear(
            hidden_size,
            hidden_size,
            device="cuda",
            params_dtype=dtype,
        ),
        "linear_op": lambda: te_ops.Sequential(
            te_ops.Linear(
                hidden_size,
                hidden_size,
                dtype=dtype,
            ),
        ),
    }
    try:
        factory = factories[module]
    except KeyError:
        raise ValueError(f"Unknown module type ({module})") from None
    # Layers are built in order, so parameter initialization consumes the RNG
    # in the same sequence for every graph mode.
    return [factory() for _ in range(num_layers)]


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Trains a stack of ``num_layers`` modules for a few steps and returns the
    params, grads and final output (see ``get_outputs``).

    Args:
        graph_mode: ``"full"`` captures the whole stack as one graphed
            callable. ``"individual"`` captures each layer separately. Any
            other value (the tests pass ``"none"``) runs eagerly.
        module: Module type name (one of ``_test_cuda_graphs_modules``).
        model_config: Tensor dimensions.
        num_layers: Number of stacked modules.
        dtype: Parameter and activation dtype.
        fp8: Whether to run under ``fp8_autocast``.
        fp8_params: Whether to create parameters under ``fp8_model_init``.
        fp8_weight_caching: Whether to pass ``is_first_microbatch``. This is
            forced off for ``"linear_op"``.
        fp8_recipe: FP8 recipe for model init, graph capture and autocast.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    def _graph(callable_: torch.nn.Module):
        """Wrap ``callable_`` with ``make_graphed_callables`` using this test's settings."""
        return make_graphed_callables(
            callable_,
            (generate_data(model_config, dtype, warmup=True),),
            num_warmup_iters=10,
            fp8_enabled=fp8,
            fp8_weight_caching=fp8_weight_caching,
            fp8_recipe=fp8_recipe,
        )

    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        layers = _make_test_modules(module, model_config, num_layers, dtype)

        # Pre-allocate gradient buffers. Together with
        # zero_grad(set_to_none=False) below, each param keeps the same grad
        # tensor across steps.
        for layer in layers:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        if graph_mode == "full":
            # Graph entire model at once.
            model = _graph(torch.nn.Sequential(*layers))
        elif graph_mode == "individual":
            # Graph each layer separately, then chain the graphed callables.
            model = _Sequential(*[_graph(layer) for layer in layers])
        else:
            model = _Sequential(*layers)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps: 3 optimizer steps, each accumulating 2 microbatches.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    # Flag the first microbatch of each optimizer step (TE's
                    # FP8 weight-caching control).
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that CUDA-graphed training reproduces eager training exactly.

    The same model is trained eagerly, as one graphed callable, and as
    per-layer graphed callables. Params, grads and outputs must match bitwise.
    """
    # Skip invalid or unsupported configurations.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_recipe.float8_block_scaling() and module == "linear_op":
        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")

    # Run model with different CUDA graph settings.
    config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)


351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run ``test_make_graphed_callables`` in FP32 with FP8 enabled and weight caching on."""
    test_make_graphed_callables(
        module=module,
        dtype=torch.float32,
        fp8=True,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic query, key and value tensors for dot product attention.

    Returns three independent CUDA tensors with ``requires_grad=True``. Each
    has shape ``(sequence_length, batch_size, num_heads, kv_channels)``. With
    ``warmup=True`` they are all-ones (used as CUDA graph capture samples);
    otherwise they are sampled from a standard normal distribution.
    """
    shape = (
        model_config.sequence_length,
        model_config.batch_size,
        model_config.num_heads,
        model_config.kv_channels,
    )
    gen_func = torch.ones if warmup else torch.randn
    return [gen_func(*shape, device="cuda", requires_grad=True, dtype=dtype) for _ in range(3)]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with dot product attention.

    Runs three forward/backward passes through a ``DotProductAttention``
    module, optionally wrapped with ``make_graphed_callables``.

    Returns:
        The module's params and grads (if any) followed by the last output,
        as collected by ``get_outputs``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # ``grad_output`` below comes from ``generate_data``, whose last dim is
    # ``hidden_size``. It must match the attention output width
    # (num_heads * kv_channels), otherwise ``backward`` fails on a shape
    # mismatch.
    assert model_config.hidden_size == model_config.num_heads * model_config.kv_channels, (
        "DotProductAttention output width (num_heads * kv_channels) must equal hidden_size"
    )

    # Create dot product attention module.
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention.

    The eager run (first) and the graphed run (second) must produce
    bitwise-identical outputs.
    """
    config = model_configs[model_config]
    eager_outputs, graph_outputs = (
        _test_cuda_graphs_with_dot_product_attention(
            with_graph=use_graph,
            model_config=config,
            dtype=dtype,
        )
        for use_graph in (False, True)
    )
    assert_all_equal(eager_outputs, graph_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments.

    Trains a ``TransformerLayer`` that uses an arbitrary self-attention mask,
    passed as the ``attention_mask`` keyword argument. The model is
    optionally wrapped with ``make_graphed_callables``. Returns the params,
    grads and final output (see ``get_outputs``).
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # (batch, 1, seq, seq) boolean mask shape used for capture and training.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.sequence_length,
        model_config.sequence_length,
    )

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    if with_graph:
        # An all-False mask serves as the capture sample; training uses random masks.
        attn_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop: 3 optimizer steps, each accumulating 2 microbatches.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _microbatch in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments (eager and graphed runs must match exactly)."""
    config = model_configs[model_config]
    kwargs = dict(model_config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers (model chunks) each process three microbatches.
    Each layer receives its own generated input.

    The forward/backward schedule is defined once, in ``layer_order``:
    - a positive entry ``k`` runs the forward pass of layer ``k - 1`` on that
      layer's next not-yet-forwarded microbatch;
    - a negative entry ``-k`` runs the backward pass of layer ``k - 1`` on its
      next not-yet-backpropagated microbatch.

    The same list is passed to ``make_graphed_callables(_order=...)``, so
    capture and execution cannot drift apart.

    Returns:
        Params, grads and all layer outputs of the last training step,
        ordered by (layer, microbatch).
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Callable to use for each (layer_idx, microbatch_idx).
    layer_forwards = {
        (layer_idx, microbatch_idx): model[layer_idx]
        for layer_idx in range(num_layers)
        for microbatch_idx in range(num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # Graphed callables are indexed layer-major:
        # index = layer_idx * num_microbatches + microbatch_idx.
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(graphed)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data. RNG draw order: per layer, per microbatch, input then grad.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = generate_data(model_config, dtype)
                grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Execute the schedule in ``layer_order``. Forward and backward passes
        # keep separate per-layer microbatch counters.
        outputs = {}
        next_forward = [0] * num_layers
        next_backward = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, next_forward[layer_idx])
                next_forward[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, next_backward[layer_idx])
                next_backward[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    outputs = [outputs[idxs] for idxs in sorted(outputs)]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.

    Eager and graphed runs must produce bitwise-identical params, grads and
    outputs.
    """
    config = model_configs[model_config]
    kwargs = dict(model_config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)