# test_cuda_graphs.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
from dataclasses import dataclass
6
7
8
9
10
from typing import List, Tuple
import pytest

import torch
from transformer_engine.pytorch import (
11
12
13
14
15
16
17
18
19
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    make_graphed_callables,
    MultiheadAttention,
    TransformerLayer,
    fp8_autocast,
    fp8_model_init,
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible


# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


# Seed both CPU and CUDA RNGs so every run starts from the same state.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
# These snapshots let reset_rng_states() restore determinism before each test case.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

36

37
@dataclass
class ModelConfig:
    """Data tensor dimensions within Transformer model"""

    # Number of tokens per sequence.
    sequence_length: int
    # Number of sequences per batch.
    batch_size: int
    # Model hidden dimension.
    hidden_size: int
    # Number of attention heads.
    num_heads: int
    # Per-head key/value projection size.
    kv_channels: int


# A single small configuration keeps the graph-capture tests fast.
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

# Module kinds exercised by the graphed-callables tests.
modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]

all_boolean = [True, False]

# Data types under test; bf16 is appended only when the GPU supports it.
dtypes = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


def reset_rng_states() -> None:
    """Rewind the CPU and CUDA RNGs to the state captured at module import."""
    torch.cuda.set_rng_state(_cuda_rng_state)
    torch.set_rng_state(_cpu_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Reset global FP8 state after every test so cases stay independent."""
    yield
    FP8GlobalStateManager.reset()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Assert that two lists of tensors are pairwise identical.

    Args:
        l1: First list of tensors.
        l2: Second list of tensors, compared element-wise against ``l1``.
        names: Optional per-tensor labels used in the failure message.

    Raises:
        AssertionError: If the lists have different lengths or any tensor
            pair differs. All mismatches are collected before failing so the
            message lists every offending tensor.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    # Accumulate mismatch descriptions in a list and join once (avoids
    # quadratic string concatenation and reports all failures together).
    mismatches = []
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not torch.equal(t1, t2):
            mismatches.append(
                f"    {names[i]}\n" if names is not None else f"    tensor at idx={i}\n"
            )
    assert not mismatches, "Output mismatches in:\n" + "".join(mismatches)


def generate_data(
    config: "ModelConfig",
    dtype: torch.dtype,
    dpa: bool = False,
    warmup: bool = False,
    return_grad_output: bool = False,
    device: str = "cuda",
) -> "List[torch.Tensor] | Tuple[List[torch.Tensor], torch.Tensor]":
    """Generate synthetic data.

    Args:
        config: Tensor dimensions for the synthetic batch.
        dtype: Data type of the generated tensors.
        dpa: Generate three query/key/value tensors for DotProductAttention
            instead of a single hidden-state tensor.
        warmup: Use deterministic all-ones tensors instead of random data
            (graph warmup must not perturb RNG-sensitive comparisons).
        return_grad_output: Also return a gradient tensor for backward().
        device: Device on which to allocate tensors (default ``"cuda"``).

    Returns:
        The list of input tensors, plus a gradient tensor when
        ``return_grad_output`` is set.
    """
    gen_func = torch.ones if warmup else torch.randn
    if dpa:
        # DPA consumes separate query/key/value tensors of identical shape.
        inputs = [
            gen_func(
                config.sequence_length,
                config.batch_size,
                config.num_heads,
                config.kv_channels,
                device=device,
                requires_grad=True,
                dtype=dtype,
            )
            for _ in range(3)
        ]
    else:
        inputs = [
            gen_func(
                config.sequence_length,
                config.batch_size,
                config.hidden_size,
                device=device,
                requires_grad=True,
                dtype=dtype,
            )
        ]

    if not return_grad_output:
        return inputs

    grad_output = torch.randn(
        config.sequence_length,
        config.batch_size,
        config.hidden_size,
        device=device,
        dtype=dtype,
    )
    return inputs, grad_output
130
131
132
133
134
135
136
137
138
139
140
141
142


def get_outputs(model, output):
    """Collect model parameters, any existing gradients, and the output for comparison."""
    collected = []
    for p in model.parameters():
        collected.append(p)
        grad = p.grad
        if grad is not None:
            collected.append(grad)
    collected.append(output)
    return collected


143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


def _test_cuda_graphs(
    *,
    config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    module: str,
    graph_mode: str,
) -> List[torch.Tensor]:
    """Run a short training loop and return params/grads/output for comparison.

    Args:
        config: Tensor dimensions for the model and synthetic data.
        num_layers: Number of stacked layers (must be 1 for "dpa").
        dtype: Parameter/activation dtype.
        fp8: Enable FP8 compute via ``fp8_autocast``.
        fp8_params: Store parameters in FP8 via ``fp8_model_init``.
        fp8_weight_caching: Exercise ``is_first_microbatch`` weight caching.
        module: One of "transformer", "layernorm_mlp", "layernorm_linear",
            "mha", "dpa", or anything else for plain Linear layers.
        graph_mode: "full" graphs the whole model at once, "individual"
            graphs each layer separately, anything else runs ungraphed.

    Returns:
        ``get_outputs`` of the model after the training loop, for equality
        comparison across graph modes.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()
    dpa = module == "dpa"

    with fp8_model_init(enabled=fp8_params):
        # Create modules.
        if module == "transformer":
            modules = [
                TransformerLayer(
                    config.hidden_size,
                    config.hidden_size,
                    config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            modules = [
                LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype)
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            modules = [
                LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype)
                for _ in range(num_layers)
            ]
        elif module == "mha":
            modules = [
                MultiheadAttention(
                    config.hidden_size,
                    config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif dpa:
            assert config.hidden_size % config.num_heads == 0, "Err."
            assert num_layers == 1, "Err."
            modules = [
                DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
                for _ in range(num_layers)
            ]
        else:
            modules = [
                Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
                for _ in range(num_layers)
            ]

        # Initialize gradient buffers.
        # NOTE: loop variable renamed from "module" to avoid shadowing the
        # "module" string parameter.
        for layer in modules:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = modules[0] if dpa else torch.nn.Sequential(*modules)
            model = make_graphed_callables(
                model,
                generate_data(config, dtype, dpa=dpa, warmup=True),
                num_warmup_iters=10,
                fp8_enabled=fp8,
                fp8_weight_caching=fp8_weight_caching,
            )
        elif graph_mode == "individual":
            # Graph individual modules
            modules = [
                make_graphed_callables(
                    layer,
                    generate_data(config, dtype, dpa=dpa, warmup=True),
                    num_warmup_iters=10,
                    fp8_enabled=fp8,
                    fp8_weight_caching=fp8_weight_caching,
                )
                for layer in modules
            ]
            model = modules[0] if dpa else _Sequential(*modules)
        else:
            model = modules[0] if dpa else _Sequential(*modules)

    # Loss function and optimizer.
    # DPA has no parameters, so no optimizer is needed for it.
    if not dpa:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Launch.
    for _ in range(3):
        if not dpa:
            optimizer.zero_grad(set_to_none=False)
        # Two microbatches per step to exercise gradient accumulation
        # (and is_first_microbatch when fp8_weight_caching is on).
        for grad_accumulation_step in range(2):
            inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
            with fp8_autocast(enabled=fp8):
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(*inputs, **kwargs)
            output.backward(grad_output)
        if not dpa:
            optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("num_layers", [1, 3])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_params", all_boolean)
@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
@pytest.mark.parametrize("module", modules)
def test_gpt_make_graphed_callables(
    dtype: torch.dtype,
    model: str,
    num_layers: int,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    module: str,
) -> None:
    """Check that graphed execution matches eager execution bit-for-bit.

    Runs the same training loop ungraphed, with the whole model graphed,
    and with each layer graphed individually, then asserts all parameters,
    gradients, and outputs are identical across the three modes.
    """
    # Skip configurations that are unsupported on this device or invalid.
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        # Fixed copy-pasted skip message: this condition is about weight
        # caching, not FP8 parameters.
        pytest.skip("FP8 needed for FP8 weight caching.")
    if module == "dpa" and num_layers > 1:
        pytest.skip("Max 1 layer for DPA.")

    config = model_configs[model]

    kwargs = dict(
        config=config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        module=module,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)