# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.

This is a tractable model: the weights and computation are specially designed
when the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
11

12
from copy import deepcopy
13
from dataclasses import dataclass
14
from typing import Any
15

16
import pytest
17
18
19
20
21
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
22
23
24
25
26
27
28
from vllm.config import (
    CompilationConfig,
    CompilationLevel,
    CUDAGraphMode,
    VllmConfig,
    set_current_vllm_config,
)
29
from vllm.forward_context import BatchDescriptor, set_forward_context
30
from vllm.utils import is_torch_equal_or_newer
31

32
33
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention  # noqa: F401
34
35


@dataclass
class LlamaConfig:
    """Hyper-parameters of the toy Llama model used by these tests.

    When ``tractable_init`` is True, the linear layers are initialized with
    identity blocks so that the model output has a closed form. Otherwise the
    weights are drawn from a Xavier-normal distribution seeded by
    ``random_seed``.
    """

    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2
    # Value written into every embedding row, i.e. the initial hidden state.
    init_value: float = 1.0
    tractable_init: bool = False
    random_seed: int = 0

    def compute_hash(self) -> str:
        """Return an MD5 hex digest of every field except ``random_seed``.

        Configs that differ only in ``random_seed`` hash identically.
        """
        import hashlib

        # Field names are unique, so sorting the (name, value) pairs never
        # needs to compare the values themselves.
        factors: list[Any] = sorted(
            (k, v) for k, v in self.__dict__.items() if k != "random_seed"
        )
        return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()

    def __post_init__(self):
        # The tractable MLP uses identity blocks of shape (mlp_size, hidden_size)
        # and (hidden_size, mlp_size); with mlp_size < hidden_size they would
        # truncate the hidden dimension.
        assert self.mlp_size >= self.hidden_size, (
            f"mlp_size ({self.mlp_size}) must be >= "
            f"hidden_size ({self.hidden_size})"
        )


class LlamaMLP(nn.Module):
    """MLP computing ``down(first_half * relu(second_half))``.

    A single fused projection produces both halves (each ``mlp_size`` wide).
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            # Both halves of the fused projection and the down projection are
            # identity blocks.
            gate_up = self.gate_up_projection.weight.data
            nn.init.eye_(gate_up[: config.mlp_size])
            nn.init.eye_(gate_up[config.mlp_size :])
            nn.init.eye_(self.down_projection.weight.data)
        else:
            # Each weight gets a fresh generator seeded with the same value,
            # so initialization does not depend on the global RNG state.
            weights = (self.gate_up_projection.weight, self.down_projection.weight)
            for weight in weights:
                nn.init.xavier_normal_(
                    weight.data,
                    generator=torch.Generator().manual_seed(config.random_seed),
                    gain=0.001,
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # For tractable_init and positive input, this is essentially an
        # elementwise square: both halves equal x zero-padded to mlp_size,
        # and the down projection keeps the first hidden_size features.
        x = self.gate_up_projection(x)
        first, second = x.chunk(2, dim=-1)
        x = first * torch.nn.functional.relu(second)
        return self.down_projection(x)


class LlamaAttention(nn.Module):
    """Toy attention block.

    A fused QKV projection feeds the custom ``torch.ops.silly.attention`` op,
    followed by an output projection.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            qkv_weight = self.qkv_projection.weight.data
            size = config.hidden_size
            # One identity block each for q, k and v.
            for start in range(0, 3 * size, size):
                nn.init.eye_(qkv_weight[start : start + size])
            nn.init.eye_(self.output_projection.weight.data)
        else:
            # Fresh, identically seeded generator per weight for reproducibility.
            weights = (self.qkv_projection.weight, self.output_projection.weight)
            for weight in weights:
                nn.init.xavier_normal_(
                    weight.data,
                    generator=torch.Generator().manual_seed(config.random_seed),
                    gain=0.001,
                )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # For tractable_init, this is:
        # output = (hidden_states * 3 + positions * 2)
        # (assuming silly.attention writes q + k + v into its output buffer;
        # the op is defined in the silly_attention module)
        qkv = self.qkv_projection(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)

        # positions has one entry per token; broadcast it across features.
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)

        attn_output = torch.empty_like(q)
        torch.ops.silly.attention(q, k, v, attn_output)

        return self.output_projection(attn_output)


class LlamaDecoderLayer(nn.Module):
    """Attention followed by MLP, each wrapped in a residual add and a ``+ 1``."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For tractable computation:
        - if residual is None, the outputs are:
            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states
                       = hidden_states * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        - if residual is not None, the outputs are:
            - residual = (hidden_states + residual + 1) * 3 + positions * 2
                         + hidden_states + residual
                       = (hidden_states + residual) * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2

        Returns:
            ``(hidden_states, residual)`` to feed into the next layer.
        """
        # Fold the incoming residual (if any) into the stream and remember the
        # result as the residual for the attention sub-block.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1

        hidden_states = self.self_attention(
            positions=positions, hidden_states=hidden_states
        )

        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1

        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile
class LlamaModel(nn.Module):
    """Toy Llama: a token embedding followed by ``config.num_layers`` layers.

    ``forward`` returns the final hidden states; the residual stream produced
    by the last layer is discarded.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        config: LlamaConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)]
        )

        # this is the initial value of the hidden states: every embedding row
        # is filled with it, so the token ids themselves do not affect the
        # output.
        self.embedding_tokens.weight.data.fill_(config.init_value)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE(review): input_ids is annotated Optional, but it is passed to
        # the embedding unconditionally, so None is not actually supported.
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def tractable_computation(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    config: LlamaConfig,
    init_value: float = 1.0,
) -> torch.Tensor:
    hidden_states = (
        torch.ones(
            input_ids.size(0),
            config.hidden_size,
            device=input_ids.device,
            dtype=input_ids.dtype,
        )
        * init_value
    )
247
248
249

    # first layer
    residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
250
    hidden_states = (residual + 1) ** 2
251
252
253
254
255

    # following layers
    for _ in range(config.num_layers - 1):
        hidden_states = hidden_states + residual
        residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
256
        hidden_states = (residual + 1) ** 2
257
258
259
260

    return hidden_states


@torch.inference_mode
def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
    """Build a ``LlamaModel`` on CUDA and simulate cudagraph capture/replay.

    Sequence: warm up with the full batch, "capture" batches of 2 and 1
    tokens, zero the first two input ids in place, then "replay" the batch of
    2 tokens.

    Returns:
        The replayed 2-token output on CPU. When ``llama_config.tractable_init``
        is set, it is also asserted to match ``tractable_computation``.
    """
    # Start with a fresh copy to make sure there's no cache dir sharing
    compile_config = deepcopy(compile_config)
    cudagraph_runtime_mode = compile_config.cudagraph_mode

    vllm_config = VllmConfig(
        compilation_config=compile_config, additional_config=llama_config
    )
    with set_current_vllm_config(vllm_config):
        model = (
            LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
            .eval()
            .cuda()
        )

    with set_forward_context({}, vllm_config=vllm_config):  # background context
        B = 16  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
        positions = torch.arange(B).cuda()

        def run_batch(num_tokens: int) -> torch.Tensor:
            # Run the first `num_tokens` tokens under the configured cudagraph
            # runtime mode with a matching batch descriptor.
            with set_forward_context(
                {},
                vllm_config=vllm_config,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                batch_descriptor=BatchDescriptor(num_tokens=num_tokens),
            ):
                return model(input_ids[:num_tokens], positions[:num_tokens])

        # warmup for the model with cudagraph_mode NONE
        model(input_ids, positions)

        # simulate cudagraphs capturing (larger size first)
        run_batch(2)
        run_batch(1)

        # change the inputs in place, then simulate cudagraphs replay
        input_ids[:2].zero_()
        output = run_batch(2).cpu()

        if llama_config.tractable_init:
            expected_output = tractable_computation(
                input_ids[:2],
                positions[:2],
                llama_config,
                init_value=llama_config.init_value,
            ).cpu()

            assert torch.allclose(output, expected_output)

    return output


@pytest.mark.parametrize(
    "backend, use_inductor_graph_partition",
    [
        ("eager", False),  # No inductor
        ("inductor", False),  # Inductor, Dynamo partition
        ("inductor", True),  # Inductor, Inductor partition
    ],
)
def test_toy_llama(
    backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
):
    """Compare toy-model outputs across three compilation setups.

    The setups are: no compilation, a single compiled graph, and piecewise
    compilation split at ``silly::attention``. The test checks the
    compilation counters for each setup and the tractable closed form.
    """
    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition only supported in torch>=2.9")

    # Disable the vLLM compile cache for 2 reasons:
    # 1. To make sure we can properly track the number of Inductor compilations.
    # 2. Inductor partitioning does not play nicely with Autograd cache (see
    #    the FIXME below).
    # NOTE(review): the tmp_path fixture is currently unused.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # compare output with and without piecewise compilation

    llama_config = LlamaConfig(
        hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
    )

    tractable_config = LlamaConfig(
        hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
    )

    compile_config_no_compile = CompilationConfig(
        level=CompilationLevel.NO_COMPILATION,
        cudagraph_mode=CUDAGraphMode.NONE,
        backend="eager",
    )

    compile_config_no_split = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor_graph_partition=use_inductor_graph_partition,
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
        backend=backend,
        cudagraph_capture_sizes=[1, 2],
    )

    # FIXME(luka/boyuan): the graph from the previous test case
    #  (no inductor partition) gets cached by AotAutograd so then the
    #  compilation with inductor partitioning incorrectly loads an unpartitioned
    #  graph and never partitions. I think this is a bug with custom inductor
    #  partitioning but does not affect vLLM more generally as vLLM uses its own
    #  cache (which takes inductor partitioning into account).
    if use_inductor_graph_partition:
        compile_config_no_split.inductor_compile_config["force_disable_caches"] = True

    compile_config_split = deepcopy(compile_config_no_split)
    compile_config_split.splitting_ops = ["silly::attention"]

    outputs = []
    with compilation_counter.expect(
        num_graphs_seen=0,
        num_piecewise_graphs_seen=0,
        num_piecewise_capturable_graphs_seen=0,
        num_backend_compilations=0,
        num_cudagraph_captured=0,
    ):
        outputs.append(run_model(llama_config, compile_config_no_compile))

    run_model(tractable_config, compile_config_no_compile)

    if backend == "inductor":
        compiler_counts = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
    else:
        compiler_counts = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

    with compilation_counter.expect(
        num_graphs_seen=1,  # one graph for the model
        num_piecewise_graphs_seen=1,
        num_piecewise_capturable_graphs_seen=1,
        num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
        num_cudagraph_captured=2,
        **compiler_counts,
    ):
        outputs.append(run_model(llama_config, compile_config_no_split))

    run_model(tractable_config, compile_config_no_split)

    # Splitting at silly::attention turns each layer's attention call into its
    # own piece (num_layers pieces), leaving num_layers + 1 capturable pieces
    # around them. With inductor partitioning, Dynamo still sees one graph.
    if use_inductor_graph_partition:
        num_piecewise_fx = 1
        num_piecewise_capturable_fx = 1
    else:
        num_piecewise_fx = 2 * llama_config.num_layers + 1
        num_piecewise_capturable_fx = 1 + llama_config.num_layers

    with compilation_counter.expect(
        num_graphs_seen=1,  # one graph for the model
        num_piecewise_graphs_seen=num_piecewise_fx,
        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
        num_backend_compilations=num_piecewise_capturable_fx,
        # num_cudagraph_sizes * num_partitions
        num_cudagraph_captured=2 * (1 + llama_config.num_layers),
    ):
        outputs.append(run_model(llama_config, compile_config_split))
    run_model(tractable_config, compile_config_split)

    for output in outputs[1:]:
        assert torch.allclose(outputs[0], output)


@torch.inference_mode
def benchmark(tiny_model: bool = True):
    """Print per-batch-size latencies for eager, full-cudagraph and
    piecewise-cudagraph execution.

    Args:
        tiny_model: if True (default), benchmark a tiny model to measure the
            overhead of piecewise cudagraph; otherwise use a config similar to
            Llama 3.1-8B.
    """
    from triton.testing import do_bench

    if tiny_model:
        # a tiny model to measure the overhead
        # of piecewise cudagraph
        llama_config = LlamaConfig(
            hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
        )
    else:
        # similar to llama 3.1-8B
        llama_config = LlamaConfig(
            hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
        )

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}
    piecewise_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    for piecewise in [False, True]:
        if piecewise:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                use_cudagraph=True,
                splitting_ops=["silly::attention"],
                cudagraph_capture_sizes=cudagraph_sizes,
            )
        else:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                cudagraph_capture_sizes=cudagraph_sizes,
            )

        vllm_config = VllmConfig(compilation_config=compilation_config)
        with set_current_vllm_config(vllm_config):
            model = (
                LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
                .eval()
                .cuda()
                .to(torch.bfloat16)
            )

        B = 256  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
        positions = torch.arange(B).cuda().to(torch.bfloat16)

        graphs = {}

        # NOTE(review): unlike run_model, no forward context is set around
        # these calls; this may need updating for the current compilation API.
        model(input_ids, positions)
        for b in cudagraph_sizes[::-1]:
            if not piecewise:
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph, pool=pool):
                    output = model(input_ids[:b], positions[:b])
                graphs[b] = (graph, output)
            else:
                output = model(input_ids[:b], positions[:b])
                graphs[b] = (model, output)

        for b in cudagraph_sizes:
            # The lambdas below reference loop variables (`b`, `input_ids`,
            # `positions`, `model`). That is only a problem if a lambda is
            # stored and called later; here each one is consumed immediately
            # by do_bench, hence the noqa.
            if piecewise:
                piecewise_cudagraph_time[b] = do_bench(
                    lambda: graphs[b][0](input_ids[:b], positions[:b])  # noqa
                )
            else:
                full_cudagraph_time[b] = do_bench(graphs[b][0].replay)
                eager_time[b] = do_bench(
                    lambda: model(input_ids[:b], positions[:b])  # noqa
                )

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
    for b in cudagraph_sizes:
        print(
            f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
            f"\t{piecewise_cudagraph_time[b]:.3f}"
        )


if __name__ == "__main__":
    # Running this file directly executes the benchmark rather than the tests.
    benchmark()