You need to sign in or sign up before continuing.
test_cuda_graphs.py 19.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from dataclasses import dataclass
6
7
import itertools
from typing import Iterable, List, Tuple, Union
8
9
10
11
import pytest

import torch
from transformer_engine.pytorch import (
12
13
14
15
16
17
18
19
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
20
    make_graphed_callables,
21
22
23
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
24
import transformer_engine.pytorch.ops as te_ops
25
from transformer_engine.common import recipe
26
27


28
# Probe FP8 / MXFP8 support once at import time. Each call yields an
# (is_available, reason) pair; the reason string is what the tests pass to
# pytest.skip when the corresponding feature is unavailable.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Seed both the CPU and CUDA generators once, then snapshot their states so
# reset_rng_states() can rewind every run to identical random streams.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

40

41
@dataclass
class ModelConfig:
    """Data tensor dimensions within Transformer model."""

    # Dim 0 of the tensors produced by generate_data().
    sequence_length: int
    # Dim 1 of the tensors produced by generate_data().
    batch_size: int
    # Dim 2 of the tensors produced by generate_data(); also the module width.
    hidden_size: int
    # Number of attention heads.
    num_heads: int
    # Per-head channel count of the dot-product-attention inputs.
    kv_channels: int


model_configs = {
    "small": ModelConfig(
        sequence_length=2,
        batch_size=32,
        hidden_size=64,
        num_heads=2,
        kv_channels=32,
    ),
}
53

54
55
56
57
58
# FP8 recipes each parametrized test is run with.
fp8_recipes = [
    recipe.DelayedScaling(),
    recipe.MXFP8BlockScaling(),
]

# Supported data types; bf16 requires sm_80 or higher, so it is only added
# when the current device reports bf16 support.
dtypes: List[torch.dtype] = [
    torch.float32,
    torch.float16,
    *([torch.bfloat16] if is_bf16_compatible() else []),
]


def reset_rng_states() -> None:
    """Revert to initial RNG state.

    Restores the CPU and CUDA generator states snapshotted at module import,
    so that every run draws the same sequence of random numbers.
    """
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that clears FP8GlobalStateManager after each test.

    There is no setup; the test body runs at the ``yield`` and the global
    FP8 state is reset during teardown.
    """
    yield
    # Teardown: drop any FP8 global state the test left behind.
    FP8GlobalStateManager.reset()


def assert_all_equal(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    names: Union[List[str], None] = None,
) -> None:
    """Check that two lists of tensors match exactly.

    Tensors are compared pairwise with ``torch.equal``. If any pair differs,
    the names (or indices, when ``names`` is None) of all mismatching pairs are
    printed. The first mismatching pair is then re-checked with
    ``torch.testing.assert_close(rtol=0, atol=0)`` to get a detailed report.

    Args:
        l1: Reference tensors.
        l2: Tensors to compare against ``l1``.
        names: Optional labels, indexed like ``l1``/``l2``, used in the report.

    Raises:
        AssertionError: If the lists differ in length or any pair differs.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    failed = [
        (i, t1, t2) for i, (t1, t2) in enumerate(zip(l1, l2)) if not torch.equal(t1, t2)
    ]
    if not failed:
        return

    labels = [f"tensor at idx={i}" if names is None else names[i] for i, _, _ in failed]
    failure_message = "Output mismatches in:" + "".join(f"\n    {label}" for label in labels)
    print(failure_message)

    _, t1, t2 = failed[0]
    torch.testing.assert_close(t1, t2, rtol=0, atol=0)
    # In case torch.equal and assert_close disagree on some tensor type, make
    # sure a mismatch detected above still fails the check.
    raise AssertionError(failure_message)
94
95
96


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic data.

    Args:
        model_config: Supplies the ``(sequence_length, batch_size, hidden_size)`` shape.
        dtype: Element type of the returned tensor.
        warmup: If True return all ones (deterministic; used as the sample
            input for ``make_graphed_callables``), otherwise draw from
            ``torch.randn`` (consumes the CUDA RNG).
        requires_grad: Whether the returned tensor tracks gradients.

    Returns:
        A CUDA tensor of shape ``(sequence_length, batch_size, hidden_size)``.
    """
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.sequence_length,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )
112
113


114
115
116
117
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return params, grads and outputs for comparison.

    For each parameter of ``model`` (in ``model.parameters()`` order), the
    parameter is appended, followed by its ``.grad`` when it is not None.
    ``output`` is appended last: as-is if it is a single tensor, otherwise
    each of its elements in iteration order.
    """
    values: List[torch.Tensor] = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


131
132
133
134
135
136
137
138
139
140
class _Sequential(torch.nn.Sequential):
    """``torch.nn.Sequential`` that also hands keyword arguments to every child.

    Each child is called as ``child(x, **kwargs)`` with the previous child's
    output as ``x``.
    """

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        hidden = input_
        for layer in self:
            hidden = layer(hidden, **kwargs)
        return hidden


141
142
143
144
145
146
147
148
149
150
151
# Module types exercised by test_make_graphed_callables. Every name listed
# here has a matching construction branch in the CUDA graph test helper.
_test_cuda_graphs_modules: List[str] = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
    "linear_op",
]


152
153
def _make_cuda_graphs_test_modules(
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
) -> List[torch.nn.Module]:
    """Construct ``num_layers`` fresh modules of the type named by ``module``.

    Raises:
        ValueError: If ``module`` is not a recognized module type.
    """
    hidden_size = model_config.hidden_size
    if module == "transformer":
        return [
            TransformerLayer(
                hidden_size,
                hidden_size,
                model_config.num_heads,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    if module == "layernorm_mlp":
        return [
            LayerNormMLP(hidden_size, hidden_size, params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "layernorm_linear":
        return [
            LayerNormLinear(hidden_size, hidden_size, params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "mha":
        return [
            MultiheadAttention(
                hidden_size,
                model_config.num_heads,
                attention_dropout=0.0,
                params_dtype=dtype,
                fuse_qkv_params=True,
            )
            for _ in range(num_layers)
        ]
    if module == "linear":
        return [
            Linear(hidden_size, hidden_size, device="cuda", params_dtype=dtype)
            for _ in range(num_layers)
        ]
    if module == "linear_op":
        return [
            te_ops.Sequential(
                te_ops.Linear(hidden_size, hidden_size, dtype=dtype),
            )
            for _ in range(num_layers)
        ]
    raise ValueError(f"Unknown module type ({module})")


def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Builds ``num_layers`` modules of type ``module``, optionally wraps them
    with ``make_graphed_callables``, and trains for 3 optimizer steps of 2
    gradient-accumulation micro-steps each.

    Args:
        graph_mode: ``"full"`` graphs the whole stack as one callable,
            ``"individual"`` graphs each module separately, and any other
            value (callers use ``"none"``) runs eagerly.
        module: Module type name, see ``_test_cuda_graphs_modules``.
        model_config: Tensor dimensions.
        num_layers: Number of stacked modules.
        dtype: Parameter and data dtype.
        fp8: Enable FP8 in ``fp8_autocast`` and graph capture.
        fp8_params: Create parameters under ``fp8_model_init``.
        fp8_weight_caching: Pass ``is_first_microbatch``; forced off for
            ``"linear_op"``.
        fp8_recipe: Recipe forwarded to FP8 init, capture and autocast.

    Returns:
        Parameters, gradients and the final output (see ``get_outputs``).

    Raises:
        ValueError: If ``module`` is unknown.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
        modules = _make_cuda_graphs_test_modules(module, model_config, num_layers, dtype)

        # Initialize gradient buffers. Contents are uninitialized; the training
        # loop calls optimizer.zero_grad(set_to_none=False) before any backward.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Options shared by every make_graphed_callables call below.
        graph_kwargs = dict(
            num_warmup_iters=10,
            fp8_enabled=fp8,
            fp8_weight_caching=fp8_weight_caching,
            fp8_recipe=fp8_recipe,
        )

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = make_graphed_callables(
                torch.nn.Sequential(*modules),
                (generate_data(model_config, dtype, warmup=True),),
                **graph_kwargs,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    **graph_kwargs,
                )
                for layer in modules
            ]
            model = _Sequential(*modules)
        else:
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
                kwargs = {}
                if fp8_weight_caching:
                    # Only the first micro-step of each step refreshes cached weights.
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


294
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that "full" and "individual" CUDA-graph modes reproduce the
    eager run exactly.

    ``model_config`` is a key into ``model_configs``.
    """

    # Skip invalid configurations.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Run model with different CUDA graph settings.
    config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        fp8_recipe=fp8_recipe,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
340
341


342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# Module types exercised with FP8 weight caching ("linear_op" is omitted:
# the operation-based API does not support weight caching).
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
    fp8_recipe: recipe.Recipe,
) -> None:
    """Run test_make_graphed_callables in FP32 with FP8 and FP8 weight caching enabled."""
    test_make_graphed_callables(
        module=module,
        dtype=torch.float32,
        fp8=True,
        fp8_params=fp8_params,
        fp8_recipe=fp8_recipe,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention.

    Returns three CUDA tensors (query, key and value, in the order they are
    passed to the module), each of shape
    ``(sequence_length, batch_size, num_heads, kv_channels)`` with
    ``requires_grad=True``. They are all ones when ``warmup`` is True,
    otherwise ``torch.randn`` samples.
    """
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.sequence_length,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Runs three forward/backward passes through ``DotProductAttention``,
    optionally wrapped with ``make_graphed_callables``, and returns what
    ``get_outputs`` collects after the last pass.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # The backward pass uses a grad_output from generate_data(), i.e.
    # (seq, batch, hidden_size). The attention output is expected to be
    # (seq, batch, num_heads * kv_channels), so those widths must agree.
    assert model_config.hidden_size == model_config.num_heads * model_config.kv_channels, (
        "hidden_size must equal num_heads * kv_channels for dot product attention"
    )

    # Create dot product attention module.
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention.

    The graphed run must reproduce the eager run bit-for-bit.
    """
    config = model_configs[model_config]
    eager_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=False,
        model_config=config,
        dtype=dtype,
    )
    graphed_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=True,
        model_config=config,
        dtype=dtype,
    )
    assert_all_equal(eager_outputs, graphed_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments.

    Trains a ``TransformerLayer`` (``self_attn_mask_type="arbitrary"``) for
    3 optimizer steps of 2 gradient-accumulation micro-steps each. Every call
    receives a fresh random boolean ``attention_mask`` keyword argument. With
    ``with_graph=True`` the layer is first captured by
    ``make_graphed_callables`` using an all-False mask as the sample kwarg.
    """
    reset_rng_states()

    # Attention masks are (batch, 1, seq, seq) boolean tensors.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.sequence_length,
        model_config.sequence_length,
    )

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(mask_shape, dtype=torch.bool, device="cuda")
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop. RNG draw order (input, grad, mask) is kept fixed so eager
    # and graphed runs see identical data.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for _ in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments.

    ``model_config`` is a key into ``model_configs``; the graphed run must
    match the eager run exactly.
    """
    config = model_configs[model_config]
    kwargs = dict(model_config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two ``Linear`` layers are run over 3 microbatches following the
    interleaved schedule ``layer_order``. With ``with_graph=True`` each
    (layer, microbatch) pair gets its own callable from
    ``make_graphed_callables(..., _order=layer_order)``.

    Returns:
        Parameters, gradients and every layer output of the last training
        iteration, sorted by ``(layer_idx, microbatch_idx)`` (see ``get_outputs``).
    """
    reset_rng_states()

    # Pipeline parallel configuration. In layer_order, +k is the next forward
    # step of layer k-1 and -k is the next backward step of layer k-1.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Map each (layer_idx, microbatch_idx) to the callable that runs it.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        layer_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # Graphed callables are indexed layer-major, then by microbatch.
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(layer_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data (fixed draw order keeps eager and graphed runs identical).
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                x = generate_data(model_config, dtype)
                dy = generate_data(model_config, dtype, requires_grad=False)
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = x
                grad_outputs[idxs] = dy

        # Cache for layer outputs.
        outputs = {}

        # Forward and backward steps, replayed from layer_order so the eager
        # schedule always matches the one given to make_graphed_callables.
        num_forwards = [0] * num_layers
        num_backwards = [0] * num_layers
        for step in layer_order:
            layer_idx = abs(step) - 1
            if step > 0:
                idxs = (layer_idx, num_forwards[layer_idx])
                num_forwards[layer_idx] += 1
                outputs[idxs] = layer_forwards[idxs](inputs[idxs])
            else:
                idxs = (layer_idx, num_backwards[layer_idx])
                num_backwards[layer_idx] += 1
                outputs[idxs].backward(grad_outputs[idxs])

        # Optimizer step.
        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.

    ``model_config`` is a key into ``model_configs``; the graphed run must
    match the eager run exactly.
    """
    config = model_configs[model_config]
    kwargs = dict(model_config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)