test_cuda_graphs.py 22 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from typing import Iterable, List, Union
6
7
8
9
import pytest

import torch
from transformer_engine.pytorch import (
10
11
12
13
14
15
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
16
17
    autocast,
    quantized_model_init,
18
    make_graphed_callables,
19
20
21
22
    is_fp8_available,
    is_fp8_block_scaling_available,
    is_mxfp8_available,
    is_bf16_available,
23
)
24
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
25
import transformer_engine.pytorch.ops as te_ops
26
from transformer_engine.common import recipe
27
from utils import ModelConfig, reset_rng_states
yuguo's avatar
yuguo committed
28
29
30
31
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
    import os
    from functools import cache
32

33
# Query once, at import time, which quantization formats this device/build
# supports; these flags gate the recipes exercised below.
fp8_available = is_fp8_available()
fp8_block_scaling_available = is_fp8_block_scaling_available()
mxfp8_available = is_mxfp8_available()

# Reset RNG states.
reset_rng_states()

# Model configurations used by the tests, keyed by name.
model_configs = dict(
    small=ModelConfig(2, 32, 2, 32),
)
44

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

def nvfp4_vanilla():
    """Return an NVFP4 block-scaling recipe whose quantization params are all ``QParams()``.

    The forward-input, forward-weight and backward-gradient quantization
    parameters are each replaced by a freshly constructed ``recipe.QParams()``.
    """
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    for field_name in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        setattr(nvfp4_recipe, field_name, recipe.QParams())
    return nvfp4_recipe


def nvfp4_rht_and_2d_quantization():
    """Return an NVFP4 recipe mixing random Hadamard transforms and 2D quantization.

    Forward inputs and backward gradients enable the random Hadamard transform
    without 2D quantization; forward weights use 2D quantization without RHT.
    """
    rht_without_2d = dict(random_hadamard_transform=True, fp4_2d_quantization=False)
    two_d_without_rht = dict(random_hadamard_transform=False, fp4_2d_quantization=True)

    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(**rht_without_2d)
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(**two_d_without_rht)
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(**rht_without_2d)
    return nvfp4_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
    """Return True if an NVFP4 recipe enables a random Hadamard transform (RHT).

    The forward-input, forward-weight and backward-gradient quantization params
    are inspected in that order. Non-NVFP4 recipes always yield False. When RHT
    is used, only bf16 inputs are supported.
    """
    if not recipe.nvfp4():
        return False
    # Lazily walk the three quantization params, stopping at the first RHT user.
    return any(
        getattr(recipe, field_name).random_hadamard_transform
        for field_name in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad")
    )


def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> List[torch.dtype]:
    """Return the input dtypes supported by an NVFP4 recipe.

    bf16 is supported for every NVFP4 recipe; fp32 is added only when the recipe
    does not use a random Hadamard transform (see ``check_rht_usage``).
    Non-NVFP4 recipes yield an empty list.

    Args:
        recipe: Quantization recipe to inspect.
        dtype: Unused; kept so existing call sites keep working.

    Returns:
        List of supported input dtypes.
    """
    if not recipe.nvfp4():
        return []
    supported_input_dtypes = [torch.bfloat16]
    # If not using RHT, fp32 inputs are supported as well.
    if not check_rht_usage(recipe):
        supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes


91
92
93
# Quantization recipes to exercise, filtered by what the current device supports.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    # NOTE(review): the NVFP4 recipe is gated on MXFP8 availability, presumably
    # as a proxy for NVFP4 hardware support — confirm.
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]
100

101
102
# Supported data types; bf16 is included only when available (requires sm_80 or higher).
dtypes: List[torch.dtype] = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_available() else []
)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture: after each test body runs, reset the global FP8 state."""
    yield  # let the test run first
    FP8GlobalStateManager.reset()

yuguo's avatar
yuguo committed
112
113
114
115
116
117
118
119
120
if IS_HIP_EXTENSION:

    @cache
    def use_hipblaslt() -> bool:
        """Return True when the hipBLASLt GEMM path is selected.

        That is the case if NVTE_USE_HIPBLASLT is set, or if NVTE_USE_ROCBLAS is
        unset. The result is cached after the first call.
        """
        if os.getenv("NVTE_USE_HIPBLASLT") is not None:
            return True
        return os.getenv("NVTE_USE_ROCBLAS") is None

    @pytest.fixture(autouse=True)
    def skip_rocblas():
        """Skip every test in this module when the rocBLAS GEMM path is selected."""
        if use_hipblaslt():
            return
        pytest.skip("CUDA graph capture not supported with rocBLAS path")
121
122

def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly.

    Every mismatching position is printed. The first mismatching pair is then
    re-checked with ``torch.testing.assert_close(..., rtol=0, atol=0)`` so the
    failure carries a detailed element-wise report.

    Args:
        l1: Reference tensors.
        l2: Tensors compared against ``l1`` position by position.
        names: Optional per-position labels used in the printed report.

    Raises:
        AssertionError: If the lists differ in length or any pair of tensors is
            not exactly equal (``torch.equal``).
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    failure_message = "Output mismatches in:"
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            label = f"tensor at idx={i}" if names is None else names[i]
            failure_message += f"\n    {label}"
            failed_tensors.append((t1, t2))

    if failed_tensors:
        print(failure_message)
        t1, t2 = failed_tensors[0]
        torch.testing.assert_close(t1, t2, rtol=0, atol=0)
        # torch.equal already reported a mismatch; fail even if assert_close
        # accepted the first pair so a mismatch can never pass silently.
        raise AssertionError(failure_message)
139
140
141


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Create a synthetic CUDA tensor of shape (max_seqlen_q, batch_size, hidden_size).

    Warmup data is all ones; otherwise values come from ``torch.randn``.
    """
    if warmup:
        factory = torch.ones
    else:
        factory = torch.randn
    shape = (
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
    )
    return factory(*shape, device="cuda", requires_grad=requires_grad, dtype=dtype)
157
158


159
160
161
162
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Collect a model's parameters, gradients and outputs for comparison.

    Each parameter is immediately followed by its ``.grad`` when one exists;
    the output tensor (or every tensor of an iterable of outputs) comes last.
    """
    collected: List[torch.Tensor] = []
    for param in model.parameters():
        collected += [param] if param.grad is None else [param, param.grad]
    if isinstance(output, torch.Tensor):
        return collected + [output]
    return collected + list(output)


176
177
178
179
180
181
182
183
184
185
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


186
187
# Supported modules
_test_cuda_graphs_modules: List[str] = [
188
189
190
    # Put linear first to test the case where the cuda context might not be set in
    # creating TMA descriptor for MXFP8 quantization.
    "linear",
191
192
193
194
195
196
197
198
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "mha",
    "linear_op",
]


199
200
def _make_cuda_graph_test_modules(
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
) -> List[torch.nn.Module]:
    """Construct ``num_layers`` modules of the type named by ``module``.

    Intended to be called inside the caller's ``quantized_model_init`` context.

    Raises:
        ValueError: If ``module`` is not a recognized module type.
    """
    hidden_size = model_config.hidden_size
    if module == "transformer":
        return [
            TransformerLayer(
                hidden_size,
                hidden_size,
                model_config.num_heads,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    if module == "layernorm_mlp":
        return [
            LayerNormMLP(hidden_size, hidden_size, params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "layernorm_linear":
        return [
            LayerNormLinear(hidden_size, hidden_size, params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "mha":
        return [
            MultiheadAttention(
                hidden_size,
                model_config.num_heads,
                attention_dropout=0.0,
                params_dtype=dtype,
                fuse_qkv_params=True,
            )
            for _ in range(num_layers)
        ]
    if module == "linear":
        return [
            Linear(hidden_size, hidden_size, device="cuda", params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "linear_op":
        return [
            te_ops.Sequential(
                te_ops.Linear(hidden_size, hidden_size, dtype=dtype),
            )
            for _ in range(num_layers)
        ]
    raise ValueError(f"Unknown module type ({module})")


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Train a small model under one CUDA graph mode and return its final state.

    Args:
        graph_mode: ``"full"`` graphs the whole model at once, ``"individual"``
            graphs each module separately, any other value runs eagerly.
        module: Module type name (see ``_test_cuda_graphs_modules``).
        model_config: Model dimensions.
        num_layers: Number of stacked modules.
        dtype: Parameter and input dtype.
        fp8: Whether quantized execution (``autocast`` / graph ``enabled``) is on.
        fp8_params: Whether modules are created under ``quantized_model_init``.
        fp8_weight_caching: Whether to pass ``is_first_microbatch``; forced off
            for ``"linear_op"``.
        fp8_recipe: Quantization recipe (may be None when ``fp8`` is False).

    Returns:
        Parameters, their gradients and the last output (see ``get_outputs``).

    Raises:
        ValueError: If ``module`` is not a recognized module type.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
        # Create modules.
        modules = _make_cuda_graph_test_modules(module, model_config, num_layers, dtype)

        # Initialize gradient buffers. A distinct loop variable keeps the
        # ``module`` parameter (a type name string) from being shadowed.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        graph_kwargs = dict(
            num_warmup_iters=10,
            enabled=fp8,
            cache_quantized_params=fp8_weight_caching,
            recipe=fp8_recipe,
        )
        if graph_mode == "full":
            # Graph entire model at once.
            model = make_graphed_callables(
                torch.nn.Sequential(*modules),
                (generate_data(model_config, dtype, warmup=True),),
                **graph_kwargs,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            graphed_layers = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    **graph_kwargs,
                )
                for layer in modules
            ]
            model = _Sequential(*graphed_layers)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps: 3 optimizer steps, each accumulating 2 microbatches.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with autocast(enabled=fp8, recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    # Flag the first microbatch of each accumulation window.
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


341
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that CUDA-graphed training matches eager training exactly.

    The same model is trained three times (whole-model graph, per-module graphs,
    eager), and parameters, gradients and outputs must be exactly equal.
    Unsupported configurations are skipped.

    Args:
        module: Module type name (see ``_test_cuda_graphs_modules``).
        model_config: Key into ``model_configs``.
        num_layers: Number of stacked modules.
        dtype: Parameter and input dtype.
        fp8_params: Whether to create parameters under ``quantized_model_init``.
        fp8_recipe: Quantization recipe, or None to run without quantization.
        fp8_weight_caching: Whether to enable quantized weight caching.
    """
    fp8 = fp8_recipe is not None

    # Options that only make sense with a quantization recipe.
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")

    # Recipe/module/dtype combinations that are not yet supported.
    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
        pytest.skip(
            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
        )
    if fp8 and fp8_recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe"
                f" {fp8_recipe.__class__.__name__}"
            )
        if fp8_params:
            pytest.skip("NVFP4 params not supported")

    # Recipes the current device cannot run.
    if fp8 and not fp8_available:
        pytest.skip("FP8 not supported on rocm GPU.")
    if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip("FP8 block scaling not supported on rocm GPU.")
    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip("MXFP8 not supported on rocm GPU.")

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=model_config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    # Put graphed callables first to test the case where the cuda context might not be set in
    # creating TMA descriptor for MXFP8 quantization.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
401
402


403
404
405
406
407
408
409
410
411
412
413
414
415
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run ``test_make_graphed_callables`` with FP8 weight caching turned on."""
    settings = dict(
        module=module,
        dtype=dtype,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
    )
    test_make_graphed_callables(fp8_weight_caching=True, **settings)


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Create three CUDA tensors that are passed positionally to DotProductAttention.

    Presumably these are query, key and value, in that order. Each tensor has
    shape (max_seqlen_q, batch_size, num_heads, kv_channels) and requires grad.
    Warmup tensors are all ones; otherwise values come from ``torch.randn``.
    """
    if warmup:
        make = torch.ones
    else:
        make = torch.randn
    shape = (
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.num_heads,
        model_config.kv_channels,
    )
    tensors = []
    for _ in range(3):
        tensors.append(make(*shape, device="cuda", requires_grad=True, dtype=dtype))
    return tensors


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Run forward/backward through DotProductAttention, optionally CUDA-graphed.

    Quantization is disabled when graphing. Returns the values collected by
    ``get_outputs`` after the last of three steps.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    attention = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph the module if requested.
    model = attention
    if with_graph:
        model = make_graphed_callables(
            attention,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        attn_inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*attn_inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention."""
    config = model_configs[model_config]
    # Eager run first, then the graphed run; results must match exactly.
    eager_outputs, graphed_outputs = (
        _test_cuda_graphs_with_dot_product_attention(
            with_graph=with_graph, model_config=config, dtype=dtype
        )
        for with_graph in (False, True)
    )
    assert_all_equal(eager_outputs, graphed_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Train a TransformerLayer that receives ``attention_mask`` as a keyword argument.

    When ``with_graph`` is set, the layer is captured with ``make_graphed_callables``
    using an all-False sample mask passed through ``sample_kwargs``. Returns the
    values collected by ``get_outputs`` after training.
    """
    reset_rng_states()

    mask_shape = (
        model_config.batch_size,
        1,
        model_config.max_seqlen_q,
        model_config.max_seqlen_kv,
    )

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    if with_graph:
        sample_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=sample_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop: 3 optimizer steps, each accumulating 2 microbatches.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _microbatch in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            # Random boolean mask with values drawn from {0, 1}.
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments."""
    config = model_configs[model_config]
    # Index 0: eager run; index 1: graphed run.
    results = [
        _test_cuda_graphs_with_kwargs(with_graph=flag, model_config=config, dtype=dtype)
        for flag in (False, True)
    ]
    assert_all_equal(*results)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two Linear layers act as model chunks that process three microbatches.
    ``layer_order`` drives both graph capture (``_order``) and the training
    schedule: an entry ``k > 0`` runs the forward of chunk ``k`` (1-based) on its
    next microbatch, and ``-k`` runs the backward of that chunk's next pending
    microbatch. Returns the values collected by ``get_outputs``.
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        Linear(model_config.hidden_size, model_config.hidden_size, params_dtype=dtype)
        for _ in range(num_layers)
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Map (layer_idx, microbatch_idx) to the callable that runs that step.
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # The returned callables are assumed to be grouped by layer, then microbatch.
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): fn
            for i, fn in enumerate(graphed)
        }
    else:
        layer_forwards = {
            (layer_idx, microbatch_idx): model[layer_idx]
            for layer_idx in range(num_layers)
            for microbatch_idx in range(num_microbatches)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = generate_data(model_config, dtype)
                grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Replay the interleaved forward/backward schedule described by layer_order.
        outputs = {}
        next_forward = [0] * num_layers
        next_backward = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, next_forward[layer_idx])
                next_forward[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, next_backward[layer_idx])
                next_backward[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    final_outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, final_outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    common = dict(model_config=model_configs[model_config], dtype=dtype)
    eager_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False, **common
    )
    graphed_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True, **common
    )
    assert_all_equal(eager_outputs, graphed_outputs)