# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import List, Tuple
import pytest

import torch
from transformer_engine.pytorch import (
    DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, make_graphed_callables,
    MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible


# Query FP8 support once at import time; tests that request FP8 are skipped
# with `reason_for_no_fp8` when it is not available.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


# Seed both the CPU and the CUDA generators with a fixed value.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Snapshot the freshly seeded RNG states so that reset_rng_states() can
# restore them later.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

@dataclass
class ModelConfig:
    """Data tensor dimensions within Transformer model"""

    # First dimension of every generated input / grad-output tensor.
    sequence_length: int
    # Second dimension of every generated input / grad-output tensor.
    batch_size: int
    # Last dimension of non-DPA inputs; also the feature size handed to the
    # module constructors.
    hidden_size: int
    # Number of attention heads.
    num_heads: int
    # Per-head channel count used for DotProductAttention inputs.
    kv_channels: int
# Named tensor-dimension presets used to parametrize the tests.
model_configs = {
    "small": ModelConfig(
        sequence_length=2,
        batch_size=32,
        hidden_size=64,
        num_heads=2,
        kv_channels=32,
    ),
}

# Module kinds the tests are parametrized over.
modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]

all_boolean = [True, False]

# Data types the tests are parametrized over.
dtypes = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    dtypes.append(torch.bfloat16)


def reset_rng_states() -> None:
    """Restore the CPU and CUDA RNG states that were recorded at import time."""
    restorers = (
        (torch.set_rng_state, _cpu_rng_state),
        (torch.cuda.set_rng_state, _cuda_rng_state),
    )
    # CPU first, then CUDA.
    for restore, saved_state in restorers:
        restore(saved_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets the global FP8 state after each test."""
    # Nothing to set up; the statement after the yield runs as teardown.
    yield
    FP8GlobalStateManager.reset()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
    """Assert that two lists of tensors are pairwise identical.

    Tensors are compared with ``torch.equal``, i.e. exact equality with no
    tolerance. All mismatching positions are collected and reported together
    in one assertion message.

    Args:
        l1: First list of tensors.
        l2: Second list of tensors; must have the same length as ``l1``.
        names: Optional labels, indexed like ``l1``/``l2``, used in the failure
            message. When ``None``, mismatches are reported by index.

    Raises:
        AssertionError: If the lists differ in length or any pair differs.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    mismatched = [
        names[i] if names is not None else f"tensor at idx={i}"
        for i, (t1, t2) in enumerate(zip(l1, l2))
        if not torch.equal(t1, t2)
    ]
    failed_tensors = "".join(f"    {label}\n" for label in mismatched)
    assert not mismatched, "Output mismatches in:\n" + failed_tensors


def generate_data(
    config: ModelConfig,
    dtype: torch.dtype,
    dpa: bool = False,
    warmup: bool = False,
    return_grad_output: bool = False,
) -> "List[torch.Tensor] | Tuple[List[torch.Tensor], torch.Tensor]":
    """Generate synthetic data.

    Args:
        config: Tensor dimensions to use.
        dtype: Data type of the generated tensors.
        dpa: When true, produce three tensors of shape
            ``(sequence_length, batch_size, num_heads, kv_channels)``;
            otherwise a single tensor of shape
            ``(sequence_length, batch_size, hidden_size)``.
        warmup: When true, fill the inputs with ones instead of random values.
        return_grad_output: When true, additionally return a random gradient
            tensor of shape ``(sequence_length, batch_size, hidden_size)``.

    Returns:
        The list of input tensors (CUDA, ``requires_grad=True``), or the pair
        ``(inputs, grad_output)`` when ``return_grad_output`` is set.
    """
    gen_func = torch.ones if warmup else torch.randn

    if dpa:
        # The caller unpacks these positionally into the attention module
        # (presumably query, key and value).
        shape = (
            config.sequence_length,
            config.batch_size,
            config.num_heads,
            config.kv_channels,
        )
        num_inputs = 3
    else:
        shape = (config.sequence_length, config.batch_size, config.hidden_size)
        num_inputs = 1

    inputs = [
        gen_func(*shape, device="cuda", requires_grad=True, dtype=dtype)
        for _ in range(num_inputs)
    ]

    if not return_grad_output:
        return inputs

    # The gradient is always random, even when `warmup` is set.
    grad_output = torch.randn(
        config.sequence_length,
        config.batch_size,
        config.hidden_size,
        device="cuda",
        dtype=dtype,
    )
    return inputs, grad_output


def get_outputs(model, output):
    """Collect parameters, their gradients and the output for comparison.

    The result lists every parameter of ``model`` in iteration order, each one
    followed by its gradient when a gradient is present, and ends with
    ``output``.
    """
    collected = []
    for weight in model.parameters():
        gradient = weight.grad
        collected += [weight] if gradient is None else [weight, gradient]
    collected += [output]
    return collected


class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""

    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        x = input_
        for module in self:
            x = module(x, **kwargs)
        return x


def _test_cuda_graphs(
    *,
    config: ModelConfig,
    num_layers: int,
    dtype: torch.dtype,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    module: str,
    graph_mode: str,
) -> List[torch.Tensor]:
    """Helper function for test.

    Resets the RNG and global FP8 state, builds ``num_layers`` modules of the
    requested kind, optionally wraps them with ``make_graphed_callables`` and
    runs three training steps of two micro-batches each.

    Args:
        config: Tensor dimensions.
        num_layers: Number of modules to stack (must be 1 for ``"dpa"``).
        dtype: Parameter and data type.
        fp8: Whether to run the forward pass under ``fp8_autocast``.
        fp8_params: Whether to create the modules under ``fp8_model_init``.
        fp8_weight_caching: Whether to enable FP8 weight caching; when set,
            ``is_first_microbatch`` is passed to the model on every call.
        module: One of ``"transformer"``, ``"layernorm_mlp"``,
            ``"layernorm_linear"``, ``"mha"``, ``"dpa"``; any other value
            selects ``Linear``.
        graph_mode: ``"full"`` graphs the whole model at once,
            ``"individual"`` graphs each module separately; any other value
            runs without graphing.

    Returns:
        The model parameters, their gradients and the last output, as produced
        by ``get_outputs``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()
    dpa = module == "dpa"

    with fp8_model_init(enabled=fp8_params):
        # Create modules.
        if module == "transformer":
            layers = [
                TransformerLayer(
                    config.hidden_size,
                    config.hidden_size,
                    config.num_heads,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    fuse_qkv_params=True,
                    params_dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        elif module == "layernorm_mlp":
            layers = [
                LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype)
                for _ in range(num_layers)
            ]
        elif module == "layernorm_linear":
            layers = [
                LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype)
                for _ in range(num_layers)
            ]
        elif module == "mha":
            layers = [
                MultiheadAttention(
                    config.hidden_size,
                    config.num_heads,
                    attention_dropout=0.0,
                    params_dtype=dtype,
                    fuse_qkv_params=True,
                )
                for _ in range(num_layers)
            ]
        elif dpa:
            assert (
                config.hidden_size % config.num_heads == 0
            ), "hidden_size must be divisible by num_heads."
            assert num_layers == 1, "DotProductAttention is only tested with a single layer."
            layers = [
                DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
                for _ in range(num_layers)
            ]
        else:
            layers = [
                Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
                for _ in range(num_layers)
            ]

        # Initialize gradient buffers.
        for layer in layers:
            for param in layer.parameters():
                param.grad = torch.empty_like(param)

        # Generate model and wrap API to return graphed version.
        if graph_mode == "full":
            # Graph entire model at once.
            model = layers[0] if dpa else torch.nn.Sequential(*layers)
            model = make_graphed_callables(
                model,
                generate_data(config, dtype, dpa=dpa, warmup=True),
                num_warmup_iters=10,
                fp8_enabled=fp8,
                fp8_weight_caching=fp8_weight_caching,
            )
        elif graph_mode == "individual":
            # Graph individual modules
            layers = [
                make_graphed_callables(
                    layer,
                    generate_data(config, dtype, dpa=dpa, warmup=True),
                    num_warmup_iters=10,
                    fp8_enabled=fp8,
                    fp8_weight_caching=fp8_weight_caching,
                )
                for layer in layers
            ]
            model = layers[0] if dpa else _Sequential(*layers)
        else:
            # No graphing: plain eager-mode reference model.
            model = layers[0] if dpa else _Sequential(*layers)

    # Loss function and optimizer.
    if not dpa:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Launch.
    for _ in range(3):
        if not dpa:
            # Keep the pre-allocated gradient buffers; only zero them.
            optimizer.zero_grad(set_to_none=False)
        for grad_accumulation_step in range(2):
            inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
            with fp8_autocast(enabled=fp8):
                kwargs = {}
                if fp8_weight_caching:
                    kwargs["is_first_microbatch"] = grad_accumulation_step == 0
                output = model(*inputs, **kwargs)
            output.backward(grad_output)
        if not dpa:
            optimizer.step()

    return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("num_layers", [1, 3])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_params", all_boolean)
@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
@pytest.mark.parametrize("module", modules)
def test_gpt_make_graphed_callables(
    dtype: torch.dtype,
    model: str,
    num_layers: int,
    fp8: bool,
    fp8_params: bool,
    fp8_weight_caching: bool,
    module: str,
) -> None:
    """Check that graphed execution reproduces the non-graphed results exactly.

    The same configuration is run without graphing, with the whole model
    graphed at once, and with every module graphed individually; parameters,
    gradients and outputs of the graphed runs must equal the non-graphed ones.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_params and not fp8:
        pytest.skip("FP8 needed for FP8 parameters.")
    if fp8_weight_caching and not fp8:
        pytest.skip("FP8 needed for FP8 weight caching.")
    if module == "dpa" and num_layers > 1:
        pytest.skip("Max 1 layer for DPA.")

    config = model_configs[model]

    kwargs = dict(
        config=config,
        num_layers=num_layers,
        dtype=dtype,
        fp8=fp8,
        fp8_params=fp8_params,
        fp8_weight_caching=fp8_weight_caching,
        module=module,
    )
    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

    # Check that results match
    assert_all_equal(outputs, graph_outputs_mode1)
    assert_all_equal(outputs, graph_outputs_mode2)