test_cuda_graphs.py 22.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from typing import Callable, Dict, Iterable, List, Tuple, Union
6
7
8
9
import pytest

import torch
from transformer_engine.pytorch import (
10
11
12
13
14
15
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
16
17
    autocast,
    quantized_model_init,
18
    make_graphed_callables,
19
20
21
22
    is_fp8_available,
    is_fp8_block_scaling_available,
    is_mxfp8_available,
    is_bf16_available,
23
)
24
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
25
import transformer_engine.pytorch.ops as te_ops
26
from transformer_engine.common import recipe
27
from utils import ModelConfig, reset_rng_states
yuguo's avatar
yuguo committed
28
29
30
31
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
    import os
    from functools import cache
32

33
# Probe, once at import time, which quantization formats this device supports;
# these flags gate recipe selection and test skips throughout the file.
fp8_available = is_fp8_available()
fp8_block_scaling_available = is_fp8_block_scaling_available()
mxfp8_available = is_mxfp8_available()

# Reset RNG states before any model or data is created.
reset_rng_states()

# Model configurations shared by the tests, keyed by name.
# NOTE(review): positional ModelConfig args presumably are
# (batch_size, max_seqlen_q, num_heads, head_dim) — confirm against utils.ModelConfig.
model_configs = {
    "small": ModelConfig(2, 32, 2, 32),
}
44

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

def nvfp4_vanilla():
    """Build an NVFP4 recipe whose input, weight and grad quantizers use default-constructed QParams."""
    quant_recipe = recipe.NVFP4BlockScaling()
    for qparams_attr in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        setattr(quant_recipe, qparams_attr, recipe.QParams())
    return quant_recipe


def nvfp4_rht_and_2d_quantization():
    """Build an NVFP4 recipe using RHT for activations/grads and 2D quantization for weights.

    The forward input and backward gradient quantizers enable the random Hadamard
    transform with 1D quantization; the weight quantizer disables RHT and enables
    2D quantization.
    """
    quant_recipe = recipe.NVFP4BlockScaling()
    # Activations and gradients share the same settings.
    rht_1d = dict(random_hadamard_transform=True, fp4_2d_quantization=False)
    quant_recipe.fp4_quant_fwd_inp = recipe.QParams(**rht_1d)
    quant_recipe.fp4_quant_fwd_weight = recipe.QParams(
        random_hadamard_transform=False, fp4_2d_quantization=True
    )
    quant_recipe.fp4_quant_bwd_grad = recipe.QParams(**rht_1d)
    return quant_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
    """Report whether an NVFP4 recipe enables the random Hadamard transform anywhere.

    Looks at the forward-input, forward-weight and backward-grad quantizer params.
    Non-NVFP4 recipes always report False. Using RHT restricts inputs to bf16.
    """
    if not recipe.nvfp4():
        return False
    return bool(
        recipe.fp4_quant_fwd_inp.random_hadamard_transform
        or recipe.fp4_quant_fwd_weight.random_hadamard_transform
        or recipe.fp4_quant_bwd_grad.random_hadamard_transform
    )


def get_nvfp4_inp_supported_dtypes(
    recipe: recipe.Recipe, dtype: torch.dtype
) -> List[torch.dtype]:
    """Return the input dtypes that ``recipe`` can quantize.

    bf16 is included for NVFP4 recipes. fp32 is included whenever the recipe does
    not use the random Hadamard transform (see ``check_rht_usage``), since RHT only
    supports bf16 inputs. This file only calls it with NVFP4 recipes; for any other
    recipe the result is just ``[torch.float32]``.

    Args:
        recipe: quantization recipe to inspect.
        dtype: unused; kept so existing call sites keep working.

    Returns:
        List of supported input ``torch.dtype`` values.
    """
    supported_input_dtypes = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
    # If not using RHT, fp32 inputs are supported as well.
    if not check_rht_usage(recipe):
        supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes


# Quantization recipes exercised by the tests; each is added only when the
# device reports support for it.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    # NOTE(review): the NVFP4 recipe is gated on MXFP8 availability here;
    # confirm this is the intended capability check.
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_available():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse teardown fixture: reset FP8GlobalStateManager after every test."""
    yield  # the test body runs here
    # Teardown: clear global quantization state so it cannot leak into the next test.
    FP8GlobalStateManager.reset()

if IS_HIP_EXTENSION:

    @cache
    def use_hipblaslt() -> bool:
        """Return False only when the environment selects the rocBLAS path.

        That is the case when ``NVTE_USE_ROCBLAS`` is set and ``NVTE_USE_HIPBLASLT``
        is not. The result is cached, so the environment is read once per process.
        """
        hipblaslt_requested = os.getenv("NVTE_USE_HIPBLASLT") is not None
        rocblas_requested = os.getenv("NVTE_USE_ROCBLAS") is not None
        return hipblaslt_requested or not rocblas_requested

    @pytest.fixture(autouse=True)
    def skip_rocblas():
        """Skip each test in this module when the rocBLAS path is selected."""
        if not use_hipblaslt():
            pytest.skip("CUDA graph capture not supported with rocBLAS path")

def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly.

    Args:
        l1, l2: tensors compared pairwise with ``torch.equal``.
        names: optional labels (one per pair) used in the failure report
            instead of ``tensor at idx=<i>``.

    Raises:
        AssertionError: if the lists differ in length, or if any pair is not
            exactly equal. All mismatching entries are printed; the first one
            is then reported in detail via ``torch.testing.assert_close`` with
            zero tolerance.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    failure_message = "Output mismatches in:"
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            label = f"tensor at idx={i}" if names is None else names[i]
            failure_message += f"\n    {label}"
            failed_tensors.append((t1, t2))

    if failed_tensors:
        print(failure_message)
        t1, t2 = failed_tensors[0]
        torch.testing.assert_close(t1, t2, rtol=0, atol=0)
        # assert_close should already have raised; never let a mismatch that
        # torch.equal detected pass silently.
        raise AssertionError(failure_message)


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic data.

    Returns a CUDA tensor of shape
    ``(max_seqlen_q, batch_size, hidden_size)`` taken from ``model_config``.
    Warmup data is all ones; otherwise values come from ``torch.randn``.
    """
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )


159
160
161
162
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return params, grads and outputs for comparison.

    Each parameter of ``model`` is followed by its ``.grad`` when one is set.
    ``output`` is appended last: a single tensor as one entry, or each tensor
    of an iterable in order.
    """
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


176
177
178
179
180
181
182
183
184
185
186
187
188
189
def reset_graphs(
    graphed_callables: Union[
        Callable,
        Tuple[Callable, ...],
        List[Callable],
        Dict[Tuple[int, int], Callable],
    ],
) -> None:
    """Reset CUDA graphs.

    Calls ``.reset()`` on a single graphed callable, on every element of a
    tuple/list, or on every value of a dict.
    """
    if isinstance(graphed_callables, (tuple, list)):
        targets = graphed_callables
    elif isinstance(graphed_callables, dict):
        targets = graphed_callables.values()
    else:
        targets = (graphed_callables,)
    for graphed in targets:
        graphed.reset()


190
191
192
193
194
195
196
197
198
199
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


200
201
# Supported modules
_test_cuda_graphs_modules: List[str] = [
202
203
204
    # Put linear first to test the case where the cuda context might not be set in
    # creating TMA descriptor for MXFP8 quantization.
    "linear",
205
206
207
208
209
210
211
212
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "mha",
    "linear_op",
]


213
214
def _create_modules(
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
) -> List[torch.nn.Module]:
    """Instantiate ``num_layers`` independent modules of the kind named by ``module``.

    Called from inside the ``quantized_model_init`` context in ``_test_cuda_graphs``.

    Raises:
        ValueError: if ``module`` is not a recognized module name.
    """
    hidden_size = model_config.hidden_size
    if module == "transformer":
        return [
            TransformerLayer(
                hidden_size,
                hidden_size,
                model_config.num_heads,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    if module == "layernorm_mlp":
        return [
            LayerNormMLP(hidden_size, hidden_size, params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "layernorm_linear":
        return [
            LayerNormLinear(hidden_size, hidden_size, params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "mha":
        return [
            MultiheadAttention(
                hidden_size,
                model_config.num_heads,
                attention_dropout=0.0,
                params_dtype=dtype,
                fuse_qkv_params=True,
            )
            for _ in range(num_layers)
        ]
    if module == "linear":
        return [
            Linear(hidden_size, hidden_size, device="cuda", params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "linear_op":
        return [
            te_ops.Sequential(
                te_ops.Linear(hidden_size, hidden_size, dtype=dtype),
            )
            for _ in range(num_layers)
        ]
    raise ValueError(f"Unknown module type ({module})")


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Builds ``num_layers`` modules of type ``module`` and wraps them according to
    ``graph_mode``:

    - ``"full"``: one graph over the whole stack.
    - ``"individual"``: one graph per module.
    - anything else: no graphs.

    It then runs three optimizer steps, each with two grad-accumulation micro-steps.

    Returns:
        Parameters, gradients and the last output (see ``get_outputs``).
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
        modules = _create_modules(module, model_config, num_layers, dtype)

        # Initialize gradient buffers.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        graph_kwargs = dict(
            num_warmup_iters=10,
            enabled=fp8,
            cache_quantized_params=fp8_weight_caching,
            recipe=fp8_recipe,
        )
        if graph_mode == "full":
            # Graph entire model at once.
            model = torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                (generate_data(model_config, dtype, warmup=True),),
                **graph_kwargs,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    **graph_kwargs,
                )
                for layer in modules
            ]
            model = _Sequential(*modules)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with autocast(enabled=fp8, recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    # Weights are (re)quantized only on the first microbatch.
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    outputs = get_outputs(model, output)
    if graph_mode == "full":
        reset_graphs(model)
    elif graph_mode == "individual":
        reset_graphs(modules)
    return outputs


360
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that full-model and per-module CUDA graphs reproduce eager results exactly.

    ``fp8_recipe=None`` runs without quantization. Unsupported combinations of
    recipe, module, dtype and hardware are skipped.
    """
    fp8 = fp8_recipe is not None
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
        pytest.skip(
            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
        )
    if fp8 and fp8_recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe"
                f" {fp8_recipe.__class__.__name__}"
            )
        if fp8_params:
            pytest.skip("NVFP4 params not supported")
    if fp8 and not fp8_available:
        pytest.skip("FP8 not supported on rocm GPU.")
    if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip("FP8 block scaling not supported on rocm GPU.")
    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip("MXFP8 not supported on rocm GPU.")

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=model_config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    # Put graphed callables first to test the case where the cuda context might not be set in
    # creating TMA descriptor for MXFP8 quantization.
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
421


422
423
424
425
426
427
428
429
430
431
432
433
434
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run ``test_make_graphed_callables`` with FP8 weight caching enabled."""
    test_make_graphed_callables(
        module=module,
        dtype=dtype,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention.

    Returns three independent CUDA tensors (passed positionally to
    ``DotProductAttention``, presumably as query, key and value). Each has shape
    ``(max_seqlen_q, batch_size, num_heads, kv_channels)`` and
    ``requires_grad=True``. Warmup data is all ones; otherwise values come
    from ``torch.randn``.
    """
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.max_seqlen_q,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Runs three forward/backward passes through a ``DotProductAttention`` module.
    When ``with_graph`` is set, the module is first wrapped by
    ``make_graphed_callables`` with quantization disabled.

    Returns:
        ``get_outputs`` of the module and the last output.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        # NOTE(review): grad_output has hidden_size as its last dim; this assumes
        # num_heads * kv_channels == hidden_size — confirm against ModelConfig.
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    outputs = get_outputs(model, output)
    if with_graph:
        reset_graphs(model)
    return outputs
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Graphed and eager DotProductAttention must produce identical results."""
    config = model_configs[model_config]
    # Eager run first (with_graph=False), then the graphed run.
    runs = {
        with_graph: _test_cuda_graphs_with_dot_product_attention(
            with_graph=with_graph, model_config=config, dtype=dtype
        )
        for with_graph in (False, True)
    }
    assert_all_equal(runs[False], runs[True])


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments.

    Trains a ``TransformerLayer`` configured with an ``"arbitrary"`` self-attention
    mask, passing the mask through the ``attention_mask`` keyword. When
    ``with_graph`` is set, the layer is graphed with an all-False boolean sample
    mask supplied via ``sample_kwargs``. Every training micro-step then feeds a
    fresh random boolean mask.

    Returns:
        Parameters, gradients and the last output (see ``get_outputs``).
    """
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Shape of the attention mask used both for graph capture and training.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.max_seqlen_q,
        model_config.max_seqlen_kv,
    )

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    outputs = get_outputs(model, output)
    if with_graph:
        reset_graphs(model)
    return outputs
601
602
603


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments (graphed vs. eager must match exactly)."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers each process three microbatches in the interleaved
    forward/backward order given by ``layer_order``. With ``with_graph``, the
    per-(layer, microbatch) callables come from ``make_graphed_callables`` using
    the same ``_order``. Otherwise each layer module is called directly.

    Returns:
        Parameters, gradients and the final iteration's outputs, sorted by
        ``(layer_idx, microbatch_idx)``.
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    # 1-based layer ids. The schedule below reads a positive entry +k as the
    # forward of layer k on its next microbatch, and -k as the backward of
    # layer k's next pending microbatch.
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Map (layer_idx, microbatch_idx) -> callable. Eagerly, all microbatches of
    # a layer share that layer's module.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        layer_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # The graphed callables are assumed to come back layer-major:
        # index i -> layer i // num_microbatches, microbatch i % num_microbatches.
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(layer_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = generate_data(model_config, dtype)
                grad_outputs[idxs] = generate_data(model_config, dtype, requires_grad=False)

        # Forward and backward steps, replayed from ``layer_order`` so the eager
        # schedule always matches the one given to make_graphed_callables.
        layer_outputs = {}
        next_forward = [0] * num_layers
        next_backward = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, next_forward[layer_idx])
                next_forward[layer_idx] += 1
                layer_outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, next_backward[layer_idx])
                next_backward[layer_idx] += 1
                layer_outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    # Keys are unique (layer, microbatch) tuples, so sorting never compares tensors.
    final_outputs = [y for _, y in sorted(layer_outputs.items())]
    outputs = get_outputs(model, final_outputs)
    if with_graph:
        reset_graphs(layer_forwards)
    return outputs
719
720
721


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.

    Graphed and eager runs must produce identical results.
    """
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)