# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
from dataclasses import dataclass
import itertools
from typing import Iterable, List, Optional, Tuple, Union

import pytest
import torch

from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
    make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
25
26


27
# Check if FP8 is supported.
# The call yields a pair: an availability flag and a reason string that the
# tests below pass to pytest as the skip reason.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

# Record initial RNG state.
# Both generators are seeded once at import time and their states are
# snapshotted so that reset_rng_states() can rewind to this exact point.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

38

39
@dataclass
class ModelConfig:
    """Data tensor dimensions within Transformer model"""

    # First dimension of the tensors built by the data generators in this file.
    sequence_length: int
    # Second dimension of those tensors.
    batch_size: int
    # Last dimension of generate_data() tensors; also the width handed to the
    # modules under test.
    hidden_size: int
    # Number of attention heads handed to the attention-based modules.
    num_heads: int
    # Per-head channel count used for dot product attention inputs.
    kv_channels: int
48

49

50
# Named model configurations; tests look these up by key.
# Positional order: sequence_length, batch_size, hidden_size, num_heads,
# kv_channels.
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}

# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


def reset_rng_states() -> None:
    """Revert to initial RNG state.

    Restores the CPU and CUDA generator states captured at module import, so
    that every helper run draws the same random parameters and data.
    """
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Reset the global FP8 state after each test that uses this fixture.

    The fixture is autouse, so it applies to tests in this module without
    being requested explicitly.
    """
    yield
    # Teardown step: executed once the test body has finished.
    FP8GlobalStateManager.reset()


def assert_all_equal(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    names: Optional[List[str]] = None,
) -> None:
    """Check that two lists of tensors match exactly.

    Args:
        l1: Reference tensors.
        l2: Tensors to compare against ``l1``, position by position.
        names: Optional labels, one per position, used in the mismatch report
            instead of the positional index.

    Raises:
        AssertionError: If the lists differ in length, or if any pair of
            tensors is not exactly equal. In the latter case a summary of
            every mismatching position is printed first, and the detailed
            error describes the first mismatching pair.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    # Collect every mismatch up front so the report lists all of them,
    # not only the first one.
    mismatches = [
        (idx, t1, t2)
        for idx, (t1, t2) in enumerate(zip(l1, l2))
        if not torch.equal(t1, t2)
    ]
    if not mismatches:
        return

    labels = [
        f"tensor at idx={idx}" if names is None else names[idx]
        for idx, _, _ in mismatches
    ]
    print("Output mismatches in:" + "".join(f"\n    {label}" for label in labels))

    # Zero tolerances turn assert_close into an exact comparison that still
    # produces a detailed diff for the first offending pair.
    _, t1, t2 = mismatches[0]
    torch.testing.assert_close(t1, t2, rtol=0, atol=0)
87
88
89


def generate_data(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Generate synthetic data.

    Args:
        model_config: Supplies the tensor dimensions.
        dtype: Data type of the generated tensor.
        warmup: If true, fill with ones (no random numbers are drawn);
            otherwise sample from a standard normal distribution.
        requires_grad: Whether the returned tensor tracks gradients.

    Returns:
        A CUDA tensor of shape (sequence_length, batch_size, hidden_size).
    """
    gen_func = torch.ones if warmup else torch.randn
    return gen_func(
        model_config.sequence_length,
        model_config.batch_size,
        model_config.hidden_size,
        device="cuda",
        requires_grad=requires_grad,
        dtype=dtype,
    )
105
106


107
108
109
110
def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return grads and params for comparison.

    Args:
        model: Module whose parameters (and gradients, where present) are
            collected.
        output: A single output tensor or an iterable of output tensors.

    Returns:
        A flat list: for each parameter in ``model.parameters()`` order, the
        parameter followed by its gradient when it has one; then the
        output tensor(s).
    """
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


124
125
126
127
128
129
130
131
132
133
class _Sequential(torch.nn.Sequential):
    """Sequential container that hands its keyword arguments to every module"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        """Chain the contained modules, passing ``kwargs`` to each call."""
        hidden = input_
        for layer in self:
            hidden = layer(hidden, **kwargs)
        return hidden


134
135
136
137
138
139
140
141
142
143
144
# Supported modules
# Module-type names understood by _test_cuda_graphs; the list also
# parametrizes test_make_graphed_callables.
_test_cuda_graphs_modules: List[str] = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
    "linear_op",
]


145
146
def _test_cuda_graphs(
    *,
    graph_mode: str,
    module: str,
    model_config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Builds ``num_layers`` copies of the requested module type, optionally
    wraps them with ``make_graphed_callables``, runs a short training loop
    with gradient accumulation, and returns the tensors to compare.

    Args:
        graph_mode: ``"full"`` graphs the whole stack at once,
            ``"individual"`` graphs each layer separately; any other value
            runs the stack without graphing.
        module: Module type name (see ``_test_cuda_graphs_modules``).
        model_config: Tensor dimensions.
        num_layers: Number of stacked layers.
        dtype: Parameter and data dtype.
        fp8: Whether to enable FP8 autocast (and FP8 in graph capture).
        fp8_params: Whether to create the modules under ``fp8_model_init``.
        fp8_weight_caching: Whether to pass ``is_first_microbatch`` to the
            model and enable weight caching in graph capture.

    Returns:
        Parameters, gradients and the final output, via ``get_outputs``.

    Raises:
        ValueError: If ``module`` is not a known module type.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Operation-based API does not support FP8 weight caching.
    if module == "linear_op":
        fp8_weight_caching = False

    # Create modules.
    with fp8_model_init(enabled=fp8_params):
        if module == "transformer":
            modules = [
                TransformerLayer(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    model_config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    model_config.hidden_size,
                    model_config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear":
            modules = [
                Linear(
                    model_config.hidden_size,
                    model_config.hidden_size,
                    device="cuda",
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "linear_op":
            modules = [
                te_ops.Sequential(
                    te_ops.Linear(
                        model_config.hidden_size,
                        model_config.hidden_size,
                        dtype=dtype,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            raise ValueError(f"Unknown module type ({module})")

        # Initialize gradient buffers.
        # The buffers start uninitialized; the training loop zeroes them in
        # place (set_to_none=False) before the first backward pass.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                (generate_data(model_config, dtype, warmup=True),),
                num_warmup_iters=10,
                fp8_enabled=fp8,
                fp8_weight_caching=fp8_weight_caching,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    layer,
                    (generate_data(model_config, dtype, warmup=True),),
                    num_warmup_iters=10,
                    fp8_enabled=fp8,
                    fp8_weight_caching=fp8_weight_caching,
                )
                for layer in modules
            ]
            model = _Sequential(*modules)
        else:
            # No graphing: plain eager stack used as the reference.
            model = _Sequential(*modules)

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training steps.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            with fp8_autocast(enabled=fp8):
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(input_, **kwargs)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


284
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
def test_make_graphed_callables(
    *,
    module: str,
    model_config: str = "small",
    num_layers: int = 3,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool = False,
) -> None:
    """Check that graphed and non-graphed runs give identical results.

    The same training loop is run without CUDA graphs, with the whole model
    graphed, and with each layer graphed individually; parameters, gradients
    and outputs must match exactly.
    """

    # Skip invalid configurations.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")

    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
    kwargs = dict(
        module=module,
        model_config=model_config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match.
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)
325
326


327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
# Module types exercised by the FP8 weight caching test. "linear_op" is left
# out: _test_cuda_graphs turns weight caching off for that module type.
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
    "transformer",
    "layernorm_mlp",
    "layernorm_linear",
    "linear",
    "mha",
]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
    "module",
    _test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
def test_make_graphed_callables_with_fp8_weight_caching(
    *,
    module: str,
    fp8_params: bool,
) -> None:
    """Run the graphed-callables comparison with FP8 weight caching enabled.

    Delegates to test_make_graphed_callables with FP8 on and float32 as the
    parameter/data dtype.
    """
    test_make_graphed_callables(
        module=module,
        dtype=torch.float32,
        fp8=True,
        fp8_params=fp8_params,
        fp8_weight_caching=True,
    )


def generate_data_for_dot_product_attention(
    model_config: ModelConfig,
    dtype: torch.dtype,
    warmup: bool = False,
) -> List[torch.Tensor]:
    """Generate synthetic data for dot product attention.

    Args:
        model_config: Supplies the tensor dimensions.
        dtype: Data type of the generated tensors.
        warmup: If true, fill with ones (no random numbers are drawn);
            otherwise sample from a standard normal distribution.

    Returns:
        Three CUDA tensors with gradients enabled, each of shape
        (sequence_length, batch_size, num_heads, kv_channels). The caller
        unpacks them as the attention module's positional inputs.
    """
    gen_func = torch.ones if warmup else torch.randn
    return [
        gen_func(
            model_config.sequence_length,
            model_config.batch_size,
            model_config.num_heads,
            model_config.kv_channels,
            device="cuda",
            requires_grad=True,
            dtype=dtype,
        )
        for _ in range(3)
    ]


def _test_cuda_graphs_with_dot_product_attention(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test.

    Runs a few forward/backward passes through a dot product attention
    module, optionally graphed, and returns the tensors to compare.

    Args:
        with_graph: Whether to wrap the module with make_graphed_callables.
        model_config: Tensor dimensions.
        dtype: Data dtype.

    Returns:
        Module parameters, their gradients (where present) and the last
        output, via ``get_outputs``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    # Create dot product attention module.
    # NOTE(review): grad_output below is built with hidden_size as its last
    # dimension, so this presumably relies on the module output having
    # num_heads * kv_channels == hidden_size features. Confirm before adding
    # configs where that does not hold.
    assert model_config.hidden_size % model_config.num_heads == 0
    model = DotProductAttention(
        model_config.num_heads,
        model_config.kv_channels,
        attention_dropout=0.0,
    )

    # Graph model if needed.
    if with_graph:
        model = make_graphed_callables(
            model,
            generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
            num_warmup_iters=10,
            fp8_enabled=False,
        )

    # Forward and backward passes.
    for _ in range(3):
        inputs = generate_data_for_dot_product_attention(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        output = model(*inputs)
        output.backward(grad_output)

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
    *,
    model_config: str = "small",
    dtype: torch.dtype,
) -> None:
    """Test CUDA graphs with dot product attention."""
    config = model_configs[model_config]

    # Reference run first, then the graphed run.
    eager_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=False,
        model_config=config,
        dtype=dtype,
    )
    graphed_outputs = _test_cuda_graphs_with_dot_product_attention(
        with_graph=True,
        model_config=config,
        dtype=dtype,
    )

    assert_all_equal(eager_outputs, graphed_outputs)


def _test_cuda_graphs_with_kwargs(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test with keyword arguments.

    Trains a single transformer layer that receives its attention mask as a
    keyword argument, optionally through a graphed callable.

    Args:
        with_graph: Whether to wrap the layer with make_graphed_callables.
        model_config: Tensor dimensions.
        dtype: Parameter and data dtype.

    Returns:
        Parameters, gradients and the last output, via ``get_outputs``.
    """
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        model_config.hidden_size,
        model_config.hidden_size,
        model_config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Shape shared by the sample mask used for capture and the per-step masks.
    mask_shape = (
        model_config.batch_size,
        1,
        model_config.sequence_length,
        model_config.sequence_length,
    )

    # Make graphed version of model if needed.
    if with_graph:
        attn_mask = torch.zeros(
            mask_shape,
            dtype=torch.bool,
            device="cuda",
        )
        model = make_graphed_callables(
            model,
            (generate_data(model_config, dtype, warmup=True),),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            input_ = generate_data(model_config, dtype)
            grad_output = generate_data(model_config, dtype, requires_grad=False)
            # A fresh random boolean mask for every step.
            attn_mask = torch.randint(
                2,
                mask_shape,
                dtype=torch.bool,
                device="cuda",
            )
            output = model(input_, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test CUDA graphs with keyword arguments."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    # Non-graphed reference run versus graphed run; results must match exactly.
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    with_graph: bool,
    model_config: ModelConfig,
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism.

    Two linear layers process three microbatches each, with forward and
    backward steps interleaved in a fixed schedule.

    Args:
        with_graph: Whether to run the layers through graphed callables.
        model_config: Tensor dimensions.
        dtype: Parameter and data dtype.

    Returns:
        Parameters, gradients and the last iteration's outputs (sorted by
        (layer, microbatch) index), via ``get_outputs``.
    """
    reset_rng_states()

    # Pipeline parallel configuration.
    # The explicit forward/backward schedule in the training loop lines up
    # entry-for-entry with layer_order when a positive entry is read as a
    # forward step and a negative entry as a backward step of the 1-based
    # layer. How make_graphed_callables itself interprets _order is not
    # visible here.
    num_layers = 2
    num_microbatches = 3
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                model_config.hidden_size,
                model_config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    # Keys are (layer_idx, microbatch_idx). Without graphs every microbatch
    # of a layer shares that layer's module.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        # One sample-argument tuple per (layer, microbatch) pair.
        sample_args = tuple(
            (generate_data(model_config, dtype, warmup=True),)
            for _ in range(num_layers * num_microbatches)
        )
        layer_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        # The returned callables are keyed layer-major: index i maps to
        # layer i // num_microbatches, microbatch i % num_microbatches.
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(layer_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                x = generate_data(model_config, dtype)
                dy = generate_data(model_config, dtype, requires_grad=False)
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = x
                grad_outputs[idxs] = dy

        # Cache for layer outputs.
        outputs = {}

        def forward(layer_idx: int, microbatch_idx: int):
            """Helper function for forward steps"""
            idxs = (layer_idx, microbatch_idx)
            outputs[idxs] = layer_forwards[idxs](inputs[idxs])

        def backward(layer_idx: int, microbatch_idx: int):
            """Helper function for backward steps"""
            outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])

        # Forward and backward steps.
        forward(0, 0)
        forward(1, 0)
        forward(0, 1)
        forward(1, 1)
        backward(1, 0)
        backward(0, 0)
        forward(0, 2)
        forward(1, 2)
        backward(1, 1)
        backward(0, 1)
        backward(1, 2)
        backward(0, 2)

        # Optimizer step.
        optimizer.step()

    # Only the final iteration's outputs survive; order them by index so both
    # runs are compared position by position.
    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    # Non-graphed reference run versus graphed run; results must match exactly.
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)