# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Callable, Dict, Iterable, List, Tuple, Union

import pytest
import torch
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    autocast,
    quantized_model_init,
    make_graphed_callables,
    is_fp8_available,
    is_fp8_block_scaling_available,
    is_mxfp8_available,
    is_bf16_available,
)
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states

# Check if FP8 is supported.
fp8_available = is_fp8_available()
fp8_block_scaling_available = is_fp8_block_scaling_available()
mxfp8_available = is_mxfp8_available()

# Reset RNG states.
reset_rng_states()

model_configs = {
    "small": ModelConfig(2, 32, 2, 32),
}


def nvfp4_vanilla():
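    """NVFP4 block-scaling recipe with default QParams for inputs, weights, and grads."""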
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
    return nvfp4_recipe


def nvfp4_rht_and_2d_quantization():
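    """NVFP4 recipe with RHT on activations and grads, and 2D quantization on weights."""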
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
        random_hadamard_transform=True, fp4_2d_quantization=False
    )
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
        random_hadamard_transform=False, fp4_2d_quantization=True
    )
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
        random_hadamard_transform=True, fp4_2d_quantization=False
    )
    return nvfp4_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
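    """Return True if any of the NVFP4 quantization params enable the random Hadamard transform."""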
    # if using RHT, we can only support bf16
    # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
    if recipe.nvfp4():
        if (
            recipe.fp4_quant_fwd_inp.random_hadamard_transform
            or recipe.fp4_quant_fwd_weight.random_hadamard_transform
            or recipe.fp4_quant_bwd_grad.random_hadamard_transform
        ):
            return True
    return False


def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> List[torch.dtype]:
    """Return the input dtypes supported by an NVFP4 recipe."""
    supported_input_dtypes = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
        # If not using RHT, we can add fp32 as well.
        if not check_rht_usage(recipe):
            supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes


# Supported quantization recipes, gated on hardware support.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_available():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
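    """Reset the FP8 global state after each test."""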
    yield
    FP8GlobalStateManager.reset()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    failure_message = "Output mismatches in:"
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            failure_message += "\n    "
            if names is None:
                failure_message += f"tensor at idx={i}"
            else:
                failure_message += names[i]
            failed_tensors.append((t1, t2))
    if failed_tensors:
        print(failure_message)
        t1, t2 = failed_tensors[0]
        torch.testing.assert_close(t1, t2, rtol=0, atol=0)


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic data."""
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )


def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return grads and params for comparison."""
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


def reset_graphs(
    graphed_callables: Union[Callable, Tuple[Callable, ...], Dict[Tuple[int, int], Callable]],
) -> None:
    """Reset CUDA graphs."""
    if isinstance(graphed_callables, (tuple, list)):
        for graphed_callable in graphed_callables:
            graphed_callable.reset()
    elif isinstance(graphed_callables, dict):
        for graphed_callable in graphed_callables.values():
            graphed_callable.reset()
    else:
        graphed_callables.reset()


class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
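        """Apply each module in sequence, forwarding the same kwargs to every module."""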
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


# Supported modules
_test_cuda_graphs_modules: List[str] = [
    # Put linear first to test the case where the CUDA context might not be set
    # when creating the TMA descriptor for MXFP8 quantization.
    "linear",
    "transformer",
    "layernorm_mlp_nocheckpoint",
    "layernorm_mlp_checkpoint",
    "layernorm_linear",
    "mha",
    "linear_op",
]


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    graph_mode selects how the model is captured: "full" graphs the whole
    model at once, "individual" graphs each module separately, and any
    other value runs the model eagerly without CUDA graphs.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp_nocheckpoint":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                    checkpoint=False,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp_checkpoint":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                    checkpoint=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers.
        for module in modules:
            for param in module.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                (generate_data(model_config, dtype, warmup=True),),
                num_warmup_iters=10,
                enabled=fp8,
                cache_quantized_params=fp8_weight_caching,
                recipe=fp8_recipe,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    module,
                    (generate_data(model_config, dtype, warmup=True),),
                    num_warmup_iters=10,
                    enabled=fp8,
                    cache_quantized_params=fp8_weight_caching,
                    recipe=fp8_recipe,
                )
                for module in modules
            ]
            model = _Sequential(*modules)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with autocast(enabled=fp8, recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    outputs = get_outputs(model, output)
    if graph_mode == "full":
        reset_graphs(model)
    elif graph_mode == "individual":
        reset_graphs(modules)
    return outputs


@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
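    """Test CUDA graphs with Transformer Engine modules."""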
    fp8 = fp8_recipe is not None
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
        pytest.skip(
            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
        )
    if fp8 and fp8_recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe"
                f" {fp8_recipe.__class__.__name__}"
            )
        if fp8_params:
            pytest.skip("NVFP4 params not supported")
    if (
        fp8
        and fp8_recipe.delayed()
        and torch.cuda.get_device_capability() >= (10, 0)
        and module == "layernorm_mlp_checkpoint"
    ):
        pytest.skip(
            "CUDA graphs not supported for LayerNormMLP "
            "with checkpoint=True, SM>=10, "
            "and DelayedScaling recipe"
        )

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=model_config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )

    # Run the graphed callables first to test the case where the CUDA context might
    # not be set when creating the TMA descriptor for MXFP8 quantization.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)


_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp_nocheckpoint",
    "layernorm_mlp_checkpoint",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Test CUDA graphs with cached FP8 weights."""
    test_make_graphed_callables(
        module=module,
        dtype=dtype,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention."""
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.max_seqlen_q,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test."""
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    outputs = get_outputs(model, output)
    if with_graph:
        reset_graphs(model)
    return outputs


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments."""
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
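    # (Pre-allocating .grad keeps gradients in fixed buffers across graph replays.)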
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(
            (
                model_config.batch_size,
                1,
                model_config.max_seqlen_q,
                model_config.max_seqlen_kv,
            ),
            dtype=torch.bool,
            device="cuda",
        )
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(
                2,
                (
                    model_config.batch_size,
                    1,
                    model_config.max_seqlen_q,
                    model_config.max_seqlen_kv,
                ),
                dtype=torch.bool,
                device="cuda",
            )
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    outputs = get_outputs(model, output)
    if with_graph:
        reset_graphs(model)
    return outputs


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism."""
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
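    # Positive entries denote forward passes of the (1-indexed) layer and negative
    # entries the corresponding backward passes, matching the schedule executed below.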
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        layer_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(layer_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                x = generate_data(model_config, dtype)
                dy = generate_data(model_config, dtype, requires_grad=False)
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = x
                grad_outputs[idxs] = dy

        # Cache for layer outputs.
        outputs = {}

        def forward(layer_idx: int, microbatch_idx: int):
            """Helper function for forward steps"""
            idxs = (layer_idx, microbatch_idx)
            outputs[idxs] = layer_forwards[idxs](inputs[idxs])

        def backward(layer_idx: int, microbatch_idx: int):
            """Helper function for backward steps"""
            outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])

        # Forward and backward steps.
        forward(0, 0)
        forward(1, 0)
        forward(0, 1)
        forward(1, 1)
        backward(1, 0)
        backward(0, 0)
        forward(0, 2)
        forward(1, 2)
        backward(1, 1)
        backward(0, 1)
        backward(1, 2)
        backward(0, 2)

        # Optimizer step.
        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    outputs = get_outputs(model, outputs)
    if with_graph:
        reset_graphs(layer_forwards)
    return outputs


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)