# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Iterable, List, Union

import pytest
import torch
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
    make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states

# Check if FP8 is supported.
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()

# Reset RNG states.
reset_rng_states()

model_configs = {
    "small": ModelConfig(32, 2, 2, 32),
}
}
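
# FP8 recipes are only appended when the current GPU supports them, so the
# FP8 parametrizations below shrink gracefully on unsupported hardware.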
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
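    # Runs after each test: clear the global FP8 state so one parametrization
    # cannot leak into the next.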
    FP8GlobalStateManager.reset()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Check that two lists of tensors match exactly."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    failure_message = "Output mismatches in:"
    failed_tensors = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            failure_message += "\n    "
            if names is None:
                failure_message += f"tensor at idx={i}"
            else:
                failure_message += names[i]
            failed_tensors.append((t1, t2))
    if failed_tensors:
        print(failure_message)
        t1, t2 = failed_tensors[0]
        torch.testing.assert_close(t1, t2, rtol=0, atol=0)


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic data."""
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.max_seqlen_q,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )


def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return grads and params for comparison."""
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


# Supported modules
_test_cuda_graphs_modules: List[str] = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
    "linear_op",
]


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test."""
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers.
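        # Grads are preallocated and later cleared with set_to_none=False so
        # the gradient buffers keep fixed addresses across CUDA graph replays.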
        for module in modules:
            for param in module.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                (generate_data(model_config, dtype, warmup=True),),
                num_warmup_iters=10,
                fp8_enabled=fp8,
                fp8_weight_caching=fp8_weight_caching,
                fp8_recipe=fp8_recipe,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    module,
                    (generate_data(model_config, dtype, warmup=True),),
                    num_warmup_iters=10,
                    fp8_enabled=fp8,
                    fp8_weight_caching=fp8_weight_caching,
                    fp8_recipe=fp8_recipe,
                )
                for module in modules
            ]
            model = _Sequential(*modules)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                kwargs = {}
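                # With FP8 weight caching, weights are quantized to FP8 on the
                # first microbatch of each step (is_first_microbatch=True) and
                # the cached FP8 weights are reused for the remaining microbatches.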
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:

    fp8 = fp8_recipe is not None
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=model_config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)


_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]
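# linear_op is omitted here because the operation-based API does not support
# FP8 weight caching (see the note in _test_cuda_graphs above).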


@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    test_make_graphed_callables(
        module=module,
        dtype=torch.float32,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention."""
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.max_seqlen_q,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test."""
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
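        # The three generated tensors serve as query/key/value; the attention
        # output is flattened back to (seq, batch, hidden), so generate_data
        # can provide a matching incoming gradient.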
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments."""
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )
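    # self_attn_mask_type="arbitrary" lets a boolean attention_mask be passed
    # as a keyword argument on every call, which is what this test exercises.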

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(
            (
                model_config.batch_size,
                1,
                model_config.max_seqlen_q,
                model_config.max_seqlen_kv,
            ),
            dtype=torch.bool,
            device="cuda",
        )
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(
                2,
                (
                    model_config.batch_size,
                    1,
                    model_config.max_seqlen_q,
                    model_config.max_seqlen_kv,
                ),
                dtype=torch.bool,
                device="cuda",
            )
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism."""
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]
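    # Schedule passed to make_graphed_callables' _order argument: positive
    # entries denote forward passes of that (1-indexed) model chunk, negative
    # entries the corresponding backward passes, mirroring the interleaved
    # pipeline schedule replayed in the training loop below.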

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
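    # Map (layer_idx, microbatch_idx) -> forward callable so the schedule below
    # can address each layer/microbatch pair, graphed or not.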
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        layer_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(layer_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                x = generate_data(model_config, dtype)
                dy = generate_data(model_config, dtype, requires_grad=False)
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = x
                grad_outputs[idxs] = dy

        # Cache for layer outputs.
        outputs = {}

        def forward(layer_idx: int, microbatch_idx: int):
            """Helper function for forward steps"""
            idxs = (layer_idx, microbatch_idx)
            outputs[idxs] = layer_forwards[idxs](inputs[idxs])

        def backward(layer_idx: int, microbatch_idx: int):
            """Helper function for backward steps"""
            outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])

        # Forward and backward steps.
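        # The call sequence below realizes layer_order above (layers are
        # 1-indexed there; negative entries correspond to backward calls).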
        forward(0, 0)
        forward(1, 0)
        forward(0, 1)
        forward(1, 1)
        backward(1, 0)
        backward(0, 0)
        forward(0, 2)
        forward(1, 2)
        backward(1, 1)
        backward(0, 1)
        backward(1, 2)
        backward(0, 2)

        # Optimizer step.
        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)