# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
from typing import Iterable, List, Optional, Union

import pytest
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
    make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe

from utils import ModelConfig, reset_rng_states

if IS_HIP_EXTENSION:
    import os
    from functools import cache
29

30
# Probe which low-precision formats this device supports. Each probe returns a
# pair; only the leading availability flag is kept.
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()

# Start from a known RNG state at import time.
reset_rng_states()

# Named model configurations used by the tests below.
model_configs = {"small": ModelConfig(2, 32, 2, 32)}
41

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

def nvfp4_vanilla():
    """Build an NVFP4 block-scaling recipe with default ``QParams`` everywhere.

    The forward-input, forward-weight and backward-grad quantizer slots each
    receive their own freshly constructed ``recipe.QParams()``.
    """
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    for slot in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        setattr(nvfp4_recipe, slot, recipe.QParams())
    return nvfp4_recipe


def nvfp4_rht_and_2d_quantization():
    """Build an NVFP4 recipe that mixes RHT and 2D quantization.

    Forward inputs and backward grads use the random Hadamard transform
    (1D quantization). Forward weights use 2D quantization without RHT.
    """
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    # slot -> (random_hadamard_transform, fp4_2d_quantization)
    slot_settings = {
        "fp4_quant_fwd_inp": (True, False),
        "fp4_quant_fwd_weight": (False, True),
        "fp4_quant_bwd_grad": (True, False),
    }
    for slot, (use_rht, use_2d) in slot_settings.items():
        qparams = recipe.QParams(random_hadamard_transform=use_rht, fp4_2d_quantization=use_2d)
        setattr(nvfp4_recipe, slot, qparams)
    return nvfp4_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
    """Report whether an NVFP4 recipe enables the random Hadamard transform.

    Returns True when ``recipe`` is NVFP4 and any of its forward-input,
    forward-weight or backward-grad quantizers has RHT enabled; False otherwise
    (including for every non-NVFP4 recipe). When RHT is in use, only bf16
    inputs are supported.
    """
    if not recipe.nvfp4():
        return False
    slots = ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad")
    # Lazily short-circuits, exactly like a chained ``or``.
    return any(getattr(recipe, slot).random_hadamard_transform for slot in slots)


def get_nvfp4_inp_supported_dtypes(
    recipe: recipe.Recipe, dtype: torch.dtype
) -> List[torch.dtype]:
    """Return the input dtypes supported for ``recipe`` on the NVFP4 path.

    bf16 is listed for NVFP4 recipes. fp32 is added whenever
    ``check_rht_usage`` reports that no quantizer uses the random Hadamard
    transform. For a non-NVFP4 recipe the result is therefore
    ``[torch.float32]``.

    Args:
        recipe: Quantization recipe to inspect.
        dtype: Unused. It is kept so existing call sites keep working.

    Returns:
        List of supported input dtypes.
    """
    supported_input_dtypes: List[torch.dtype] = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
    # RHT restricts inputs to bf16; without it, fp32 is supported as well.
    if not check_rht_usage(recipe):
        supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes


88
89
90
# Low-precision recipes exercised by the tests, each gated on device support.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    # NOTE(review): NVFP4 is only tested when MXFP8 is available; there is no
    # separate NVFP4 availability check here. Confirm this gating is intended.
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())
97

98
99
# Data types the parametrized tests run with. bf16 requires sm_80 or higher,
# so it is only included when the device reports bf16 support.
dtypes: List[torch.dtype] = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets ``FP8GlobalStateManager`` after every test."""
    # Nothing to set up; the reset happens at teardown.
    yield None
    FP8GlobalStateManager.reset()

if IS_HIP_EXTENSION:

    @cache
    def use_hipblaslt() -> bool:
        """Report whether the hipBLASLt GEMM path is selected (result is cached).

        True if NVTE_USE_HIPBLASLT is set, or if NVTE_USE_ROCBLAS is not set.
        """
        hipblaslt_requested = os.getenv("NVTE_USE_HIPBLASLT") is not None
        rocblas_requested = os.getenv("NVTE_USE_ROCBLAS") is not None
        return hipblaslt_requested or not rocblas_requested

    @pytest.fixture(autouse=True)
    def skip_rocblas():
        """Autouse fixture: skip each test when the rocBLAS path is active."""
        if use_hipblaslt():
            return
        pytest.skip("CUDA graph capture not supported with rocBLAS path")
118
119

def assert_all_equal(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    names: Optional[List[str]] = None,
) -> None:
    """Assert that two lists of tensors match exactly, pair by pair.

    Args:
        l1: First list of tensors.
        l2: Second list of tensors, compared element-wise against ``l1``.
        names: Optional labels for the pairs, used in the failure report.
            Without labels, pairs are reported by index.

    Raises:
        AssertionError: If the lists differ in length, or if any pair is not
            ``torch.equal``. The message lists every mismatching pair;
            ``torch.testing.assert_close`` adds details for the first one.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    mismatches = [i for i, (t1, t2) in enumerate(zip(l1, l2)) if not torch.equal(t1, t2)]
    if not mismatches:
        return

    labels = [f"tensor at idx={i}" if names is None else names[i] for i in mismatches]
    failure_message = "Output mismatches in:" + "".join(f"\n    {label}" for label in labels)

    # Zero tolerances give an element-level diff report for the first mismatch.
    first = mismatches[0]
    torch.testing.assert_close(
        l1[first],
        l2[first],
        rtol=0,
        atol=0,
        msg=lambda details: f"{failure_message}\n{details}",
    )
    # If assert_close accepted a pair that torch.equal rejected, still fail:
    # a detected mismatch must never pass silently.
    raise AssertionError(failure_message)
136
137
138


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate a synthetic activation tensor on the CUDA device.

    The tensor has shape ``(max_seqlen_q, batch_size, hidden_size)``, with the
    sizes taken from ``model_config``.

    Args:
        model_config: Supplies the tensor dimensions.
        dtype: Tensor dtype.
        warmup: If True, fill with ones (the tests use this for the sample
            inputs given to ``make_graphed_callables``). Otherwise draw from a
            standard normal distribution.
        requires_grad: Whether autograd tracks the tensor.

    Returns:
        The generated tensor.
    """
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )
154
155


156
157
158
159
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return params, grads and outputs for comparison.

    Each parameter of ``model`` is followed by its gradient when one is set.
    The output tensor, or every tensor of an iterable of outputs, comes last.
    """
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


173
174
175
176
177
178
179
180
181
182
class _Sequential(torch.nn.Sequential):
    """``torch.nn.Sequential`` variant that passes keyword arguments to every child.

    Each child is called as ``child(x, **kwargs)`` with the same ``kwargs``.
    ``x`` starts as the input and is replaced by each child's result.
    """

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        hidden = input_
        for layer in iter(self):
            hidden = layer(hidden, **kwargs)
        return hidden


183
184
# Supported modules
# Module types exercised by ``test_make_graphed_callables``.
_test_cuda_graphs_modules: List[str] = [
    # Put linear first to test the case where the CUDA context might not be set
    # when creating the TMA descriptor for MXFP8 quantization.
    "linear",
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "mha",
    "linear_op",
]


196
197
def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: Optional[recipe.Recipe],
) -> List[torch.Tensor]:
    """Train a small stack of TE modules and return params, grads and output.

    Args:
        graph_mode: ``"full"`` graphs the whole stack with a single
            ``make_graphed_callables`` call. ``"individual"`` graphs each layer
            separately. Any other value runs eagerly.
        module: Layer type to stack (one of ``_test_cuda_graphs_modules``).
        model_config: Sizes used to build layers and synthetic data.
        num_layers: Number of layers in the stack.
        dtype: Parameter and data dtype.
        fp8: Whether to run the forward pass under ``fp8_autocast``.
        fp8_params: Whether parameters are created under ``fp8_model_init``.
        fp8_weight_caching: Whether to pass ``is_first_microbatch``. It is
            forced off for ``"linear_op"``.
        fp8_recipe: Quantization recipe, or ``None`` when FP8 is disabled.

    Returns:
        Parameters, their gradients and the final output (see ``get_outputs``).

    Raises:
        ValueError: If ``module`` is not a known layer type.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers. A distinct loop name keeps the
        # ``module`` string argument from being shadowed.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        graph_kwargs = dict(
            num_warmup_iters=10,
            fp8_enabled=fp8,
            fp8_weight_caching=fp8_weight_caching,
            fp8_recipe=fp8_recipe,
        )
        if graph_mode == "full":
            # Graph entire model at once.
            model = make_graphed_callables(
                torch.nn.Sequential(*modules),
                (generate_data(model_config, dtype, warmup=True),),
                **graph_kwargs,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            graphed_layers = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    **graph_kwargs,
                )
                for layer in modules
            ]
            model = _Sequential(*graphed_layers)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps: 3 optimizer steps, each accumulating grads over 2 microbatches.
    for _ in range(3):
        # Keep the grad tensors allocated (zeroed) rather than setting them to None.
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


338
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: Optional[recipe.Recipe],
    fp8_weight_caching: bool = False,
) -> None:
    """Check that CUDA-graphed TE modules reproduce eager results exactly.

    The same training run is executed with whole-model graphing, per-layer
    graphing and no graphing. All three must produce ``torch.equal`` params,
    grads and outputs. ``fp8_recipe=None`` disables FP8. Unsupported
    combinations are skipped.
    """
    fp8 = fp8_recipe is not None
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
        pytest.skip(
            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
        )
    if fp8 and fp8_recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe"
                f" {fp8_recipe.__class__.__name__}"
            )
        if fp8_params:
            pytest.skip("NVFP4 params not supported")

    # Run model with different CUDA graph settings.
    config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    # Put graphed callables first to test the case where the CUDA context might
    # not be set when creating the TMA descriptor for MXFP8 quantization.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
392
393


394
395
396
397
398
399
400
401
402
403
404
405
406
# Module types exercised with FP8 weight caching. "linear_op" is left out
# because the operation-based API does not support FP8 weight caching.
_test_make_graphed_callables_with_fp8_weight_caching_modules = (
    "transformer layernorm_mlp layernorm_linear linear mha".split()
)


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run ``test_make_graphed_callables`` with FP8 weight caching enabled.

    Only FP8 recipes are parametrized here (no ``None``), so FP8 is always on.
    """
    test_make_graphed_callables(
        module=module,
        dtype=dtype,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic inputs for dot product attention.

    Returns three independent CUDA tensors. Each has shape
    ``(max_seqlen_q, batch_size, num_heads, kv_channels)`` and
    ``requires_grad=True``. The tests pass them positionally to
    ``DotProductAttention``; presumably they are query, key and value, in that
    order. NOTE(review): all three use ``max_seqlen_q``, so this assumes
    ``max_seqlen_q == max_seqlen_kv``.

    Args:
        model_config: Supplies the tensor dimensions.
        dtype: Tensor dtype.
        warmup: If True, fill with ones; otherwise draw from a standard
            normal distribution.
    """
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.max_seqlen_q,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Run DotProductAttention forward/backward three times and collect results.

    Args:
        with_graph: Whether to wrap the module with ``make_graphed_callables``
            (with FP8 disabled).
        model_config: Sizes for the module and the synthetic data.
        dtype: Dtype of the synthetic inputs and output gradients.

    Returns:
        Parameters, gradients and the last output (see ``get_outputs``).
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention.

    The eager run happens first and the graphed run second. Their params,
    grads and outputs must match exactly.
    """
    config = model_configs[model_config]
    results = {
        with_graph: _test_cuda_graphs_with_dot_product_attention(
            with_graph=with_graph, model_config=config, dtype=dtype
        )
        for with_graph in (False, True)
    }
    assert_all_equal(results[False], results[True])


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Train a TransformerLayer that takes an ``attention_mask`` keyword argument.

    Args:
        with_graph: If True, graph the layer with ``make_graphed_callables``,
            passing an all-False boolean mask as the sample keyword argument.
        model_config: Sizes for the layer, the data and the mask.
        dtype: Parameter and data dtype.

    Returns:
        Parameters, gradients and the last output (see ``get_outputs``).
    """
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Shape of the boolean attention mask, shared by the sample and training masks.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.max_seqlen_q,
        model_config.max_seqlen_kv,
    )

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop: 3 optimizer steps, each accumulating grads over 2 microbatches.
    for _step in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _microbatch in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments.

    Eager and graphed runs of a masked TransformerLayer must produce exactly
    equal params, grads and outputs.
    """
    config = model_configs[model_config]
    kwargs = dict(model_config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers process three microbatches. The step schedule comes
    from ``layer_order``:

    - a positive entry ``k`` runs the forward of layer ``k - 1`` on that
      layer's next pending microbatch;
    - a negative entry ``-k`` runs the backward of layer ``k - 1`` on its next
      microbatch.

    When ``with_graph`` is set, the same ``layer_order`` is passed to
    ``make_graphed_callables`` as ``_order``.

    Args:
        with_graph: Whether to use graphed callables.
        model_config: Sizes for the layers and the synthetic data.
        dtype: Parameter and data dtype.

    Returns:
        Parameters, gradients and the last iteration's outputs, sorted by
        ``(layer_idx, microbatch_idx)`` (see ``get_outputs``).
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Callable per (layer_idx, microbatch_idx). Eagerly, all microbatches of a
    # layer share that layer's module.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # Map graphed callables layer-major:
        # index i -> (layer i // num_microbatches, microbatch i % num_microbatches).
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(graphed_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data (input first, then grad, per slot, to keep RNG order).
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = generate_data(model_config, dtype)
                grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Cache for layer outputs.
        outputs = {}

        # Forward and backward steps, driven by ``layer_order`` so the eager
        # schedule cannot drift from the one used for graph capture.
        next_forward = [0] * num_layers
        next_backward = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, next_forward[layer_idx])
                next_forward[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, next_backward[layer_idx])
                next_backward[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.

    Eager and graphed runs of the interleaved schedule must produce exactly
    equal params, grads and outputs.
    """
    config = model_configs[model_config]
    kwargs = dict(model_config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)