# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for CUDA graph capture of Transformer Engine modules via make_graphed_callables."""

from dataclasses import dataclass
import itertools
from typing import Iterable, List, Tuple, Union

import pytest
import torch

from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
    make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe

from utils import ModelConfig, reset_rng_states

# Check if FP8 is supported.
# Each check returns (is_available, reason); the reason string is used as the
# pytest skip message when the feature is unavailable.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Reset RNG states.
reset_rng_states()

# Model configurations, looked up by name through the tests' `model_config` argument.
# NOTE(review): ModelConfig arguments are positional; confirm their order in utils.ModelConfig.
model_configs = {
    "small": ModelConfig(32, 2, 2, 32),
}

# FP8 recipes to parametrize over; recipes the current GPU cannot run are skipped
# inside the tests.
fp8_recipes = [
    recipe.DelayedScaling(),
    recipe.MXFP8BlockScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that calls ``FP8GlobalStateManager.reset()`` after each test.

    Setup does nothing: the test body runs at the ``yield``. The reset happens
    when control returns to the fixture for teardown.
    """
    yield None  # hand control to the test; no fixture value is provided
    fp8_state_manager = FP8GlobalStateManager
    fp8_state_manager.reset()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly.

    Args:
        l1: Reference tensors.
        l2: Tensors compared position by position against ``l1``.
        names: Optional per-position labels for the failure report; when
            omitted, positions are reported as ``"tensor at idx=<i>"``.

    Raises:
        AssertionError: If the lists have different lengths, or if any pair is
            not exactly equal according to ``torch.equal``. Every mismatching
            position is printed; the raised error details the first one.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    failure_message = "Output mismatches in:"
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if torch.equal(t1, t2):
            continue
        label = f"tensor at idx={i}" if names is None else names[i]
        failure_message += f"\n    {label}"
        failed_tensors.append((t1, t2))

    if not failed_tensors:
        return

    print(failure_message)
    # Re-compare the first mismatch with zero tolerance to get a detailed,
    # element-wise error message.
    t1, t2 = failed_tensors[0]
    torch.testing.assert_close(t1, t2, rtol=0, atol=0)
    # torch.equal and assert_close perform different checks. If assert_close
    # happens to accept the pair, still fail rather than let a detected
    # mismatch pass silently.
    raise AssertionError(failure_message)


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic data.

    Returns a CUDA tensor of shape
    ``(model_config.max_seqlen_q, model_config.batch_size, model_config.hidden_size)``.

    Args:
        model_config: Source of the sequence length, batch size and hidden size.
        dtype: Data type of the generated tensor.
        warmup: If True, fill with ones (this file uses such tensors as the
            sample inputs for ``make_graphed_callables``). Otherwise sample from
            a standard normal distribution.
        requires_grad: Whether the returned tensor requires gradients.
    """
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )


def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return params, grads, and outputs for comparison.

    Each parameter of ``model`` is followed by its gradient (when one exists).
    The model output(s) come last.
    """
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


115
116
117
118
119
120
121
122
123
124
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


125
126
127
128
129
130
131
132
133
134
135
# Supported modules
_test_cuda_graphs_modules: List[str] = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
    "linear_op",
]


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Builds ``num_layers`` modules of type ``module`` and optionally wraps them
    with ``make_graphed_callables``. It then runs three optimizer steps, each
    with two gradient-accumulation microbatches.

    Args:
        graph_mode: Graphing strategy.
            - ``"full"``: graph the whole stack as a single callable.
            - ``"individual"``: graph every module separately.
            - any other value (the tests pass ``"none"``): run the modules eagerly.
        module: Module type name, one of ``_test_cuda_graphs_modules``.
        model_config: Model dimensions used to build modules and data.
        num_layers: Number of stacked modules.
        dtype: Parameter and data type.
        fp8: Whether to graph with FP8 enabled and run under ``fp8_autocast``.
        fp8_params: Whether to create parameters under ``fp8_model_init``.
        fp8_weight_caching: Whether to pass ``is_first_microbatch`` to the model.
            Forced off for ``"linear_op"``.
        fp8_recipe: FP8 recipe used for model init, graphing and autocast.

    Returns:
        Parameters, gradients and the last output, as collected by ``get_outputs``.

    Raises:
        ValueError: If ``module`` is not a recognized module type.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers.
        # The loop variable is not named `module`, so the string argument is not shadowed.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                (generate_data(model_config, dtype, warmup=True),),
                num_warmup_iters=10,
                fp8_enabled=fp8,
                fp8_weight_caching=fp8_weight_caching,
                fp8_recipe=fp8_recipe,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    num_warmup_iters=10,
                    fp8_enabled=fp8,
                    fp8_weight_caching=fp8_weight_caching,
                    fp8_recipe=fp8_recipe,
                )
                for layer in modules
            ]
            model = _Sequential(*modules)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        # set_to_none=False keeps the pre-allocated gradient buffers in place.
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    # True only for the first accumulation step of each optimizer step.
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Test that CUDA-graphed models reproduce eager results exactly.

    Runs the same training loop three ways: eagerly (``"none"``), with the whole
    model graphed (``"full"``), and with each module graphed individually
    (``"individual"``). Parameters, gradients and outputs must then match
    according to ``assert_all_equal``.
    """
    # Skip invalid configurations.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_recipe.float8_block_scaling() and module == "linear_op":
        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")

    # Run model with different CUDA graph settings.
    config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)


330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Test CUDA graphs with FP8 weight caching enabled.

    Delegates to ``test_make_graphed_callables`` with FP32 parameters, FP8
    compute and ``fp8_weight_caching=True``.
    """
    test_make_graphed_callables(
        module=module,
        dtype=torch.float32,
        fp8=True,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention.

    Returns three independent CUDA tensors. Each has shape
    ``(max_seqlen_q, batch_size, num_heads, kv_channels)`` taken from
    ``model_config`` and requires gradients. Callers in this file pass them
    positionally to ``DotProductAttention`` (presumably as query, key and
    value; confirm against its forward signature).

    Args:
        model_config: Source of the tensor dimensions.
        dtype: Data type of the tensors.
        warmup: If True, fill with ones; otherwise sample from a standard normal.
    """
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.max_seqlen_q,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Runs three forward/backward passes through a ``DotProductAttention``
    module and returns the values collected by ``get_outputs``. When
    ``with_graph`` is set, the module is first wrapped with
    ``make_graphed_callables`` with FP8 disabled.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Compare eager and CUDA-graphed DotProductAttention runs with assert_all_equal."""
    config = model_configs[model_config]
    # Eager run first, then the graphed run (the generator is consumed in order).
    eager_outputs, graphed_outputs = (
        _test_cuda_graphs_with_dot_product_attention(
            with_graph=use_graph,
            model_config=config,
            dtype=dtype,
        )
        for use_graph in (False, True)
    )
    assert_all_equal(eager_outputs, graphed_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments.

    Trains a ``TransformerLayer`` configured with ``self_attn_mask_type="arbitrary"``.
    A fresh random boolean mask is passed through the ``attention_mask``
    keyword argument on every microbatch. When ``with_graph`` is set, the layer
    is graphed using an all-zeros sample mask given via ``sample_kwargs``.
    Returns the values collected by ``get_outputs``.
    """
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Attention mask shape, shared by the graph sample and the training masks.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.max_seqlen_q,
        model_config.max_seqlen_kv,
    )

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _microbatch in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments.

    Eager and graphed runs of ``_test_cuda_graphs_with_kwargs`` must match
    according to ``assert_all_equal``.
    """
    config = model_configs[model_config]
    kwargs = dict(model_config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers act as model chunks and process three microbatches.
    ``layer_order`` is the schedule:
    - an entry ``k > 0`` is a forward step of layer ``k`` (1-based) on that
      layer's next microbatch;
    - an entry ``k < 0`` is the backward step of layer ``-k`` on its next
      microbatch.

    When ``with_graph`` is set, the schedule is passed to
    ``make_graphed_callables`` as ``_order``. The training loop replays the
    same list, so capture order and execution order cannot diverge.

    Returns:
        Parameters, gradients and the last iteration's outputs sorted by
        ``(layer_idx, microbatch_idx)``, as collected by ``get_outputs``.
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Callable for each (layer_idx, microbatch_idx); eager modules by default.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }

    # Make graphed version of model if needed.
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        graphed_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # One graphed callable per (layer, microbatch), indexed layer-major.
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(graphed_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                x = generate_data(model_config, dtype)
                dy = generate_data(model_config, dtype, requires_grad=False)
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = x
                grad_outputs[idxs] = dy

        # Replay the schedule in `layer_order`, tracking each layer's next
        # forward and next backward microbatch separately.
        outputs = {}
        next_forward = [0] * num_layers
        next_backward = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, next_forward[layer_idx])
                next_forward[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, next_backward[layer_idx])
                next_backward[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    final_outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, final_outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.

    Eager and graphed runs of the simulated pipeline schedule must match
    according to ``assert_all_equal``.
    """
    config = model_configs[model_config]
    kwargs = dict(model_config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)