# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import Iterable, List, Tuple, Union
import pytest

import torch
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    make_graphed_callables,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible


# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()


@dataclass
class ModelConfig:
    """Data tensor dimensions within Transformer model"""

    sequence_length: int
    batch_size: int
    hidden_size: int
    num_heads: int
    kv_channels: int


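# "small" config: sequence_length=2, batch_size=32, hidden_size=64, num_heads=2, kv_channels=32.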
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}

modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]

all_boolean = [True, False]

dtypes = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


def reset_rng_states() -> None:
    """revert back to initial RNG state."""
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    failed = False
    failed_tensors = ""
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            failed = True
            failed_tensors += (
                f"    {names[i]}\n" if names is not None else f"    tensor at idx={i}\n"
            )
    assert not failed, "Output mismatches in:\n" + failed_tensors


def generate_data(
    config: ModelConfig,
    dtype: torch.dtype,
    dpa: bool = False,
    warmup: bool = False,
    return_grad_output: bool = False,
) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], torch.Tensor]]:
    """Generate synthetic data."""
    gen_func = torch.ones if warmup else torch.randn
    if dpa:
        inputs = [
            gen_func(
                config.sequence_length,
                config.batch_size,
                config.num_heads,
                config.kv_channels,
                device="cuda",
                requires_grad=True,
                dtype=dtype,
            )
            for _ in range(3)
        ]
    else:
        inputs = [
            gen_func(
                config.sequence_length,
                config.batch_size,
                config.hidden_size,
                device="cuda",
                requires_grad=True,
                dtype=dtype,
            )
        ]

    if not return_grad_output:
        return inputs

    grad_output = torch.randn(
        config.sequence_length,
        config.batch_size,
        config.hidden_size,
        device="cuda",
        dtype=dtype,
    )
    return inputs, grad_output


def get_outputs(
    model: torch.nn.Module,
    output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
    """Return grads and params for comparsion."""
    values = []
    for param in model.parameters():
        values.append(param)
        if param.grad is not None:
            values.append(param.grad)
    if isinstance(output, torch.Tensor):
        values.append(output)
    else:
        values.extend(output)
    return values


class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


def _test_cuda_graphs(
    *,
    config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    module: str,
    graph_mode: str,
) -> List[torch.Tensor]:
    """Helper function for CUDA graph test."""
    reset_rng_states()
    FP8GlobalStateManager.reset()
    dpa = module == "dpa"

    with fp8_model_init(enabled=fp8_params):
        # Create modules.
        if module == "transformer":
            modules = [
                TransformerLayer(
                    config.hidden_size,
                    config.hidden_size,
                    config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype)
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype)
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    config.hidden_size,
                    config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif dpa:
            assert config.hidden_size % config.num_heads == 0, "hidden_size must be divisible by num_heads."
            assert num_layers == 1, "DotProductAttention test supports only a single layer."
            modules = [
                DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
                for _ in range(num_layers)
            ]
        else:
            modules = [
                Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
                for _ in range(num_layers)
            ]

        # Initialize gradient buffers.
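        # Pre-allocating .grad keeps gradient buffers at fixed addresses
        # across iterations, which CUDA graph capture relies on.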
        for module in modules:
            for param in module.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = modules[0] if dpa else torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                generate_data(config, dtype, dpa=dpa, warmup=True),
                num_warmup_iters=10,
                fp8_enabled=fp8,
                fp8_weight_caching=fp8_weight_caching,
            )
        elif graph_mode == "individual":
            # Graph individual modules.
            modules = [
                make_graphed_callables(
                    module,
                    generate_data(config, dtype, dpa=dpa, warmup=True),
                    num_warmup_iters=10,
                    fp8_enabled=fp8,
                    fp8_weight_caching=fp8_weight_caching,
                )
                for module in modules
            ]
            model = modules[0] if dpa else _Sequential(*modules)
        else:
            model = modules[0] if dpa else _Sequential(*modules)

    # Optimizer.
    if not dpa:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        if not dpa:
            optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
            with fp8_autocast(enabled=fp8):
                kwargs = {}
                if fp8_weight_caching:
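                    # is_first_microbatch=True recomputes the FP8 weight casts;
                    # later microbatches reuse the cached casts.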
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(*inputs, **kwargs)
            output.backward(grad_output)
        if not dpa:
            optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("num_layers", [1, 3])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_params", all_boolean)
@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
@pytest.mark.parametrize("module", modules)
def test_gpt_make_graphed_callables(
    dtype: torch.dtype,
    model: str,
    num_layers: int,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    module: str,
) -> None:
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if module == "dpa" and num_layers > 1:
        pytest.skip("Max 1 layer for DPA.")

    config = model_configs[model]

    kwargs = dict(
        config=config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        module=module,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)


def _test_cuda_graphs_with_kwargs(
    *,
    config: ModelConfig,
    dtype: torch.dtype,
    with_graph: bool,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism."""
    reset_rng_states()

    # Initialize model.
    model = TransformerLayer(
        config.hidden_size,
        config.hidden_size,
        config.num_heads,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        self_attn_mask_type="arbitrary",
        fuse_qkv_params=True,
        params_dtype=dtype,
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    if with_graph:
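        # Capture the graph with a dummy all-False mask; real masks are passed at run time.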
        attn_mask = torch.zeros(
            (config.batch_size, 1, config.sequence_length, config.sequence_length),
            dtype=torch.bool,
            device="cuda",
        )
        model = make_graphed_callables(
            model,
            generate_data(config, dtype, warmup=True),
            sample_kwargs=dict(attention_mask=attn_mask),
            allow_unused_input=True,
        )

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            inputs, grad_output = generate_data(config, dtype, return_grad_output=True)
            attn_mask = torch.randint(
                2,
                (config.batch_size, 1, config.sequence_length, config.sequence_length),
                dtype=torch.bool,
                device="cuda",
            )
            output = model(*inputs, attention_mask=attn_mask)
            output.backward(grad_output)
        optimizer.step()

    return get_outputs(model, output)


def test_make_graphed_callables_with_kwargs(
    dtype: torch.dtype = torch.float32,
    model: str = "small",
) -> None:
    """Test CUDA graphs with keyword arguments."""
    config = model_configs[model]
    kwargs = dict(config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
    assert_all_equal(outputs, graph_outputs)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
    *,
    config: ModelConfig,
    dtype: torch.dtype,
    with_graph: bool,
) -> List[torch.Tensor]:
    """Simulate Megatron-LM interleaved pipeline parallelism."""
    reset_rng_states()

    # Pipeline parallel configuration.
    num_layers = 2
    num_microbatches = 3
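    # layer_order encodes the interleaved schedule passed to make_graphed_callables
    # via _order: positive k is a forward pass of (1-indexed) layer k, negative k
    # the matching backward pass.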
    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]

    # Initialize model.
    model = torch.nn.ModuleList(
        [
            Linear(
                config.hidden_size,
                config.hidden_size,
                params_dtype=dtype,
            )
            for _ in range(num_layers)
        ]
    )

    # Initialize gradient buffers.
    for param in model.parameters():
        param.grad = torch.empty_like(param)

    # Make graphed version of model if needed.
    layer_forwards = {
        (i % num_layers, i // num_layers): model[i % num_layers]
        for i in range(num_layers * num_microbatches)
    }
    if with_graph:
        sample_args = tuple(
            generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches)
        )
        layer_forwards = make_graphed_callables(
            tuple(model),
            sample_args,
            allow_unused_input=True,
            _order=layer_order,
        )
        layer_forwards = {
            (i // num_microbatches, i % num_microbatches): forward
            for i, forward in enumerate(layer_forwards)
        }

    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Training loop.
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)

        # Generate data.
        inputs = {}
        grad_outputs = {}
        for layer_idx in range(num_layers):
            for microbatch_idx in range(num_microbatches):
                x, dy = generate_data(config, dtype, return_grad_output=True)
                idxs = (layer_idx, microbatch_idx)
                inputs[idxs] = x[0]
                grad_outputs[idxs] = dy

        # Cache for layer outputs.
        outputs = {}

        def forward(layer_idx: int, microbatch_idx: int):
            """Helper function for forward steps"""
            idxs = (layer_idx, microbatch_idx)
            outputs[idxs] = layer_forwards[idxs](inputs[idxs])

        def backward(layer_idx: int, microbatch_idx: int):
            """Helper function for backward steps"""
            outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])

        # Forward and backward steps, interleaved to match layer_order above.
        forward(0, 0)
        forward(1, 0)
        forward(0, 1)
        forward(1, 1)
        backward(1, 0)
        backward(0, 0)
        forward(0, 2)
        forward(1, 2)
        backward(1, 1)
        backward(0, 1)
        backward(1, 2)
        backward(0, 2)

        # Optimizer step.
        optimizer.step()

    outputs = [y for _, y in sorted(outputs.items())]
    return get_outputs(model, outputs)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
    dtype: torch.dtype = torch.float16,
    model: str = "small",
) -> None:
    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
    config = model_configs[model]
    kwargs = dict(config=config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)