# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
from typing import Callable, Dict, Iterable, List, Tuple, Union
6
7
8
9
import pytest

import torch
from transformer_engine.pytorch import (
10
11
12
13
14
15
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
16
17
    autocast,
    quantized_model_init,
18
    make_graphed_callables,
19
20
21
22
    is_fp8_available,
    is_fp8_block_scaling_available,
    is_mxfp8_available,
    is_bf16_available,
23
)
24
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
25
import transformer_engine.pytorch.ops as te_ops
26
from transformer_engine.common import recipe
27
from utils import ModelConfig, reset_rng_states
yuguo's avatar
yuguo committed
28
29
30
31
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
    import os
    from functools import cache
32

33
# Query once, at import time, which quantized formats are usable.
fp8_available = is_fp8_available()
fp8_block_scaling_available = is_fp8_block_scaling_available()
mxfp8_available = is_mxfp8_available()

# Reset RNG states.
reset_rng_states()

# Named model configurations that the tests look up by key.
model_configs = dict(small=ModelConfig(2, 32, 2, 32))
44

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

def nvfp4_vanilla():
    """Build an NVFP4 block-scaling recipe with default ``QParams`` in every slot.

    The forward-input, forward-weight and backward-gradient quantizer
    parameters are each replaced by a freshly constructed ``recipe.QParams()``.
    """
    vanilla = recipe.NVFP4BlockScaling()
    for slot in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        # A separate QParams instance per slot, assigned in the same order
        # as a hand-written sequence of attribute assignments would.
        setattr(vanilla, slot, recipe.QParams())
    return vanilla


def nvfp4_rht_and_2d_quantization():
    """Build an NVFP4 recipe mixing random Hadamard transform and 2D quantization.

    Forward inputs and backward gradients enable the random Hadamard
    transform without 2D quantization; forward weights do the opposite.
    """

    def _rht_without_2d():
        # Setting shared by the activation and gradient quantizers.
        return recipe.QParams(random_hadamard_transform=True, fp4_2d_quantization=False)

    mixed = recipe.NVFP4BlockScaling()
    mixed.fp4_quant_fwd_inp = _rht_without_2d()
    mixed.fp4_quant_fwd_weight = recipe.QParams(
        random_hadamard_transform=False,
        fp4_2d_quantization=True,
    )
    mixed.fp4_quant_bwd_grad = _rht_without_2d()
    return mixed


def check_rht_usage(recipe: recipe.Recipe) -> bool:
    """Tell whether an NVFP4 recipe enables the random Hadamard transform anywhere.

    Returns ``False`` for any recipe whose ``nvfp4()`` check is false.
    Otherwise returns ``True`` as soon as one of the forward-input,
    forward-weight or backward-gradient quantizer parameters has
    ``random_hadamard_transform`` set.
    """
    # if using RHT, we can only support bf16
    if not recipe.nvfp4():
        return False
    quantizer_slots = ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad")
    # The generator keeps attribute access lazy, so later slots are only
    # inspected when the earlier ones do not use RHT.
    return any(getattr(recipe, slot).random_hadamard_transform for slot in quantizer_slots)


def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> List[torch.dtype]:
    """Return the list of input dtypes the tests treat as supported for ``recipe``.

    ``torch.bfloat16`` is included when ``recipe.nvfp4()`` is true, and
    ``torch.float32`` is included whenever ``check_rht_usage(recipe)`` is false.

    NOTE(review): the float32 check sits outside the ``nvfp4()`` guard, so a
    non-NVFP4 recipe yields ``[torch.float32]`` -- confirm this is intended.
    The ``dtype`` argument is accepted but not read by this function.
    """
    supported_input_dtypes = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
    # if not using RHT, we can add fp32 as well
    if not check_rht_usage(recipe):
        supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes


91
92
93
# Quantization recipes to parametrize over, gated on the availability flags.
fp8_recipes = []
if mxfp8_available:
    # NOTE(review): the NVFP4 recipe is gated on MXFP8 availability here --
    # confirm that is the intended condition.
    fp8_recipes += [recipe.MXFP8BlockScaling(), nvfp4_rht_and_2d_quantization()]
if fp8_block_scaling_available:
    fp8_recipes += [recipe.Float8BlockScaling()]
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_available():  # bf16 requires sm_80 or higher
    dtypes += [torch.bfloat16]


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture: run the test, then call ``FP8GlobalStateManager.reset()``."""
    # Everything before the yield would run as setup; there is none here.
    yield
    # Teardown: executed after each test finishes.
    FP8GlobalStateManager.reset()

yuguo's avatar
yuguo committed
112
113
114
115
116
117
118
119
120
if IS_HIP_EXTENSION:

    @cache
    def use_hipblaslt() -> bool:
        """Decide from the environment whether the hipBLASLt path is selected.

        True when ``NVTE_USE_HIPBLASLT`` is set, or when ``NVTE_USE_ROCBLAS``
        is not set.  The answer is cached after the first call.
        """
        env = os.environ
        return "NVTE_USE_HIPBLASLT" in env or "NVTE_USE_ROCBLAS" not in env

    @pytest.fixture(autouse=True)
    def skip_rocblas():
        """Autouse fixture that skips every test when hipBLASLt is not selected."""
        if use_hipblaslt():
            return
        pytest.skip("CUDA graph capture not supported with rocBLAS path")
121
122

def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly.

    Args:
        l1: Reference tensors.
        l2: Tensors to compare against ``l1``, position by position.
        names: Optional labels, indexed like ``l1``/``l2``, used in the
            mismatch summary instead of the positional index.

    Raises:
        AssertionError: If the lists differ in length or if any pair of
            tensors is not exactly equal.  A summary of every mismatching
            entry is printed before the detailed comparison of the first
            mismatch is raised.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    mismatched_labels = []
    mismatched_pairs = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if torch.equal(t1, t2):
            continue
        mismatched_labels.append(f"tensor at idx={i}" if names is None else names[i])
        mismatched_pairs.append((t1, t2))

    if not mismatched_pairs:
        return

    failure_message = "\n    ".join(["Output mismatches in:", *mismatched_labels])
    print(failure_message)

    # Zero tolerances make assert_close an exact comparison that also
    # produces a detailed diff for the first mismatching pair.
    t1, t2 = mismatched_pairs[0]
    torch.testing.assert_close(t1, t2, rtol=0, atol=0)

    # torch.equal already reported a mismatch; never let it pass silently
    # if assert_close happens to accept the pair.
    raise AssertionError(failure_message)
139
140
141


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Create a synthetic CUDA tensor sized from ``model_config``.

    The tensor dimensions are ``(max_seqlen_q, batch_size, hidden_size)``.
    With ``warmup=True`` the tensor is filled with ones (``torch.ones``);
    otherwise it is drawn from ``torch.randn``.
    """
    dims = (
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
    )
    factory = torch.ones if warmup else torch.randn
    return factory(*dims, device="cuda", requires_grad=requires_grad, dtype=dtype)
157
158


159
160
161
162
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Collect params, grads and outputs into one flat list for comparison.

    Each parameter of ``model`` is listed, immediately followed by its
    gradient when one is set.  ``output`` is appended last: a single tensor
    as one entry, any other iterable element by element.
    """
    collected = []
    for weight in model.parameters():
        gradient = weight.grad
        collected += [weight] if gradient is None else [weight, gradient]
    collected.extend([output] if isinstance(output, torch.Tensor) else output)
    return collected


176
177
178
179
180
181
182
183
184
185
186
187
188
189
def reset_graphs(
    graphed_callables: Union[Callable, Tuple[Callable, ...], Dict[Tuple[int, int], Callable]],
) -> None:
    """Reset CUDA graphs.

    Calls ``reset()`` on every graphed callable that was passed in.

    Args:
        graphed_callables: A single object with a ``reset()`` method, a
            tuple or list of such objects, or a dict whose values are such
            objects.
    """
    if isinstance(graphed_callables, dict):
        targets = graphed_callables.values()
    elif isinstance(graphed_callables, (tuple, list)):
        targets = graphed_callables
    else:
        # A lone callable: treat it as a collection of one.
        targets = (graphed_callables,)
    for graphed in targets:
        graphed.reset()


190
191
192
193
194
195
196
197
198
199
class _Sequential(torch.nn.Sequential):
    """Sequential container that hands the same keyword arguments to every stage."""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        """Feed ``input_`` through each contained module in order.

        Every module is called with the previous module's result and the
        unchanged ``kwargs``.
        """
        hidden = input_
        for stage in self:
            hidden = stage(hidden, **kwargs)
        return hidden


200
201
# Module kinds exercised by the CUDA graph tests.
_test_cuda_graphs_modules: List[str] = [
    # "linear" stays first: it covers the case where the CUDA context might
    # not be set yet when the TMA descriptor for MXFP8 quantization is created.
    "linear",
    "transformer",
    "layernorm_mlp_nocheckpoint",
    "layernorm_mlp_checkpoint",
    "layernorm_linear",
    "mha",
    "linear_op",
]


214
215
def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Train a small stack of layers, optionally captured as CUDA graphs.

    ``graph_mode`` selects the capture strategy: ``"full"`` graphs the whole
    stack as one callable, ``"individual"`` graphs each layer on its own, and
    any other value runs the stack without graphs.  ``module`` names the layer
    type to build; an unrecognized name raises ``ValueError``.

    Returns the list produced by ``get_outputs``: parameters, their gradients
    and the output of the last forward pass.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    hidden_size = model_config.hidden_size

    # One zero-argument builder per supported layer kind.  Only the selected
    # builder is ever invoked, once per layer.
    layer_builders = {
        "transformer": lambda: TransformerLayer(
            hidden_size,
            hidden_size,
            model_config.num_heads,
            hidden_dropout=0.0,
            attention_dropout=0.0,
            fuse_qkv_params=True,
            params_dtype=dtype,
        ),
        "layernorm_mlp_nocheckpoint": lambda: LayerNormMLP(
            hidden_size,
            hidden_size,
            params_dtype=dtype,
            checkpoint=False,
        ),
        "layernorm_mlp_checkpoint": lambda: LayerNormMLP(
            hidden_size,
            hidden_size,
            params_dtype=dtype,
            checkpoint=True,
        ),
        "layernorm_linear": lambda: LayerNormLinear(
            hidden_size,
            hidden_size,
            params_dtype=dtype,
        ),
        "mha": lambda: MultiheadAttention(
            hidden_size,
            model_config.num_heads,
            attention_dropout=0.0,
            params_dtype=dtype,
            fuse_qkv_params=True,
        ),
        "linear": lambda: Linear(
            hidden_size,
            hidden_size,
            device="cuda",
            params_dtype=dtype,
        ),
        "linear_op": lambda: te_ops.Sequential(
            te_ops.Linear(
                hidden_size,
                hidden_size,
                dtype=dtype,
            ),
        ),
    }

    def capture(target):
        """Wrap ``target`` with ``make_graphed_callables`` using a warmup sample."""
        return make_graphed_callables(
            target,
            (generate_data(model_config, dtype, warmup=True),),
            num_warmup_iters=10,
            enabled=fp8,
            cache_quantized_params=fp8_weight_caching,
            recipe=fp8_recipe,
        )

    # Layers are created and graphed inside the quantized-init context.
    with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module not in layer_builders:
            raise ValueError(f"Unknown module type ({module})")
        build_layer = layer_builders[module]
        modules = [build_layer() for _ in range(num_layers)]

        # Initialize gradient buffers.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Assemble the model, graphing it as requested.
        if graph_mode == "full":
            # Graph entire model at once.
            model = capture(torch.nn.Sequential(*modules))
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [capture(layer) for layer in modules]
            model = _Sequential(*modules)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps: three optimizer steps, two accumulation passes each.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            forward_kwargs = (
                {"is_first_microbatch": grad_accumulation_step == 0} if fp8_weight_caching else {}
            )
            with autocast(enabled=fp8, recipe=fp8_recipe):
                output = model(input_, **forward_kwargs)
            output.backward(grad_output)
        optimizer.step()

    outputs = get_outputs(model, output)

    # Release captured graphs before returning.
    if graph_mode == "full":
        reset_graphs(model)
    elif graph_mode == "individual":
        reset_graphs(modules)
    return outputs
370
371


372
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that graphed and non-graphed runs produce identical results.

    The same training loop is run with the whole model graphed, with each
    layer graphed individually, and without graphs; all parameters, gradients
    and outputs must match exactly.  ``fp8_recipe=None`` disables quantized
    execution.  Unsupported parameter combinations are skipped.
    """
    fp8 = fp8_recipe is not None

    # Combinations that require FP8 but have no recipe.
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")

    # Recipe/module combinations that the test marks as unsupported.
    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
        pytest.skip(
            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
        )
    if fp8 and fp8_recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe"
                f" {fp8_recipe.__class__.__name__}"
            )
        if fp8_params:
            pytest.skip("NVFP4 params not supported")
    if (
        fp8
        and fp8_recipe.delayed()
        and torch.cuda.get_device_capability() >= (10, 0)
        and module == "layernorm_mlp_checkpoint"
    ):
        pytest.skip(
            "CUDA graphs not supported for LayerNormMLP "
            "with checkpoint=True, SM>=10, "
            "and DelayedScaling recipe"
        )

    # Availability checks based on the module-level flags.
    if fp8 and not fp8_available:
        pytest.skip("FP8 not supported on rocm GPU.")
    if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip("FP8 block scaling not supported on rocm GPU.")
    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip("MXFP8 not supported on rocm GPU.")

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=model_config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    # Put graphed callables first to test the case where the cuda context might not be set in
    # creating TMA descriptor for MXFP8 quantization.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
443
444


445
446
# Module kinds covered by the FP8 weight caching test.
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp_nocheckpoint",
    "layernorm_mlp_checkpoint",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Re-run ``test_make_graphed_callables`` with FP8 weight caching enabled."""
    forwarded = dict(
        module=module,
        dtype=dtype,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
    )
    test_make_graphed_callables(fp8_weight_caching=True, **forwarded)


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Create three synthetic CUDA tensors to feed dot product attention.

    Each tensor has dimensions
    ``(max_seqlen_q, batch_size, num_heads, kv_channels)`` and requires grad.
    With ``warmup=True`` the tensors are all ones; otherwise they come from
    ``torch.randn``.
    """
    dims = (
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.num_heads,
        model_config.kv_channels,
    )
    factory = torch.ones if warmup else torch.randn
    tensors = []
    for _ in range(3):
        tensors.append(factory(*dims, device="cuda", requires_grad=True, dtype=dtype))
    return tensors


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Run forward/backward passes through ``DotProductAttention``.

    When ``with_graph`` is true the module is first wrapped by
    ``make_graphed_callables``.  Returns the list built by ``get_outputs``
    from the model and the output of the last forward pass.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    attention = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    model = attention
    if with_graph:
        warmup_inputs = generate_data_for_dot_product_attention(model_config, dtype, warmup=True)
        model = make_graphed_callables(
            attention,
            warmup_inputs,
            num_warmup_iters=10,
            enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        qkv = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*qkv)
        output.backward(grad_output)

    outputs = get_outputs(model, output)
    if with_graph:
        reset_graphs(model)
    return outputs
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention."""
    config = model_configs[model_config]
    # Non-graphed reference run first, then the graphed run.
    reference = _test_cuda_graphs_with_dot_product_attention(
        with_graph=False,
        model_config=config,
        dtype=dtype,
    )
    graphed = _test_cuda_graphs_with_dot_product_attention(
        with_graph=True,
        model_config=config,
        dtype=dtype,
    )
    assert_all_equal(reference, graphed)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Train a ``TransformerLayer`` that receives its attention mask as a keyword argument.

    When ``with_graph`` is true the layer is wrapped by
    ``make_graphed_callables`` with an all-False sample mask passed through
    ``sample_kwargs``.  Returns the list built by ``get_outputs``.
    """
    reset_rng_states()

    # Initialize model.
    hidden_size = model_config.hidden_size
    model = TransformerLayer(
        hidden_size,
        hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # The sample mask and the per-step masks share these dimensions.
    mask_dims = (
        model_config.batch_size,
        1,
        model_config.max_seqlen_q,
        model_config.max_seqlen_kv,
    )

    # Make graphed version of model if needed.
    if with_graph:
        sample_mask = torch.zeros(mask_dims, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=sample_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop: three optimizer steps, two accumulation passes each.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            # Draw order (input, grad, mask) is kept fixed so runs stay comparable.
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_dims, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    outputs = get_outputs(model, output)
    if with_graph:
        reset_graphs(model)
    return outputs
625
626
627


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments."""
    config = model_configs[model_config]
    # Non-graphed reference run first, then the graphed run.
    reference = _test_cuda_graphs_with_kwargs(
        with_graph=False,
        model_config=config,
        dtype=dtype,
    )
    graphed = _test_cuda_graphs_with_kwargs(
        with_graph=True,
        model_config=config,
        dtype=dtype,
    )
    assert_all_equal(reference, graphed)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers process three microbatches each, with forward and
    backward steps interleaved in a fixed schedule.  When ``with_graph`` is
    true, one graphed callable per (layer, microbatch) pair is created by
    ``make_graphed_callables`` using the ``_order`` argument.  Returns the
    list built by ``get_outputs`` from the model and the layer outputs of the
    final training iteration, sorted by (layer, microbatch).
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    hidden_size = model_config.hidden_size
    model = torch.nn.ModuleList(
        [Linear(hidden_size, hidden_size, params_dtype=dtype) for _ in range(num_layers)]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Map (layer, microbatch) -> callable that performs that forward step.
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # The flat result is indexed layer-major: index = layer * num_microbatches + microbatch.
        layer_forwards = {
            divmod(flat_idx, num_microbatches): fn for flat_idx, fn in enumerate(graphed)
        }
    else:
        layer_forwards = {
            (layer, microbatch): model[layer]
            for microbatch in range(num_microbatches)
            for layer in range(num_layers)
        }

    # Fixed interleaving of forward ("fwd") and backward ("bwd") steps,
    # each entry being (kind, layer index, microbatch index).
    schedule = (
        ("fwd", 0, 0),
        ("fwd", 1, 0),
        ("fwd", 0, 1),
        ("fwd", 1, 1),
        ("bwd", 1, 0),
        ("bwd", 0, 0),
        ("fwd", 0, 2),
        ("fwd", 1, 2),
        ("bwd", 1, 1),
        ("bwd", 0, 1),
        ("bwd", 1, 2),
        ("bwd", 0, 2),
    )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data: input then grad for each (layer, microbatch), layer-major.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                key = (layer_idx, microbatch_idx)
                inputs[key] = generate_data(model_config, dtype)
                grad_outputs[key] = generate_data(model_config, dtype, requires_grad=False)

        # Cache for layer outputs.
        layer_outputs = {}

        # Forward and backward steps.
        for kind, layer_idx, microbatch_idx in schedule:
            key = (layer_idx, microbatch_idx)
            if kind == "fwd":
                layer_outputs[key] = layer_forwards[key](inputs[key])
            else:
                layer_outputs[key].backward(grad_outputs[key])

        # Optimizer step.
        optimizer.step()

    # Outputs of the last iteration, ordered by (layer, microbatch).
    ordered_outputs = [layer_outputs[key] for key in sorted(layer_outputs)]
    outputs = get_outputs(model, ordered_outputs)
    if with_graph:
        reset_graphs(layer_forwards)
    return outputs
743
744
745


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    config = model_configs[model_config]
    # Non-graphed reference run first, then the graphed run.
    reference = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        model_config=config,
        dtype=dtype,
    )
    graphed = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        model_config=config,
        dtype=dtype,
    )
    assert_all_equal(reference, graphed)