"tests/vscode:/vscode.git/clone" did not exist on "1e57b1ee6312325a9dab99918422693c38f2b203"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.

This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
11

12
from copy import deepcopy
13
from dataclasses import dataclass
14
from typing import Any
15

16
import pytest
17
18
19
20
21
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
22
23
from vllm.config import (
    CompilationConfig,
24
    CompilationMode,
25
26
27
28
    CUDAGraphMode,
    VllmConfig,
    set_current_vllm_config,
)
29
from vllm.forward_context import BatchDescriptor, set_forward_context
30
from vllm.utils.torch_utils import is_torch_equal_or_newer
31

32
33
from ...utils import create_new_process_for_each_test

34
35
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention  # noqa: F401
36
37


38
39
40
41
42
43
@dataclass
class LlamaConfig:
    """Hyper-parameters of the toy Llama model used in this test."""

    # Width of the hidden states and of the token embeddings.
    hidden_size: int = 128
    # Width of the intermediate MLP activation.
    mlp_size: int = 256
    # Number of rows in the embedding table.
    vocab_size: int = 128
    # Number of decoder layers.
    num_layers: int = 2

    # Constant every embedding weight is filled with.
    init_value: float = 1.0
    # When True the linear layers are initialised with identity blocks,
    # otherwise with a seeded Xavier-normal initialisation.
    tractable_init: bool = False
    # Seed of the generator used for the random initialisation.
    random_seed: int = 0

    def compute_hash(self) -> str:
        """Return an MD5 hex digest of all fields except ``random_seed``.

        The (name, value) pairs are sorted so the digest does not depend on
        the order in which the fields are stored.
        """
        import hashlib

        factors: list[Any] = sorted(
            (name, value)
            for name, value in self.__dict__.items()
            if name != "random_seed"
        )
        return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()

    def __post_init__(self):
        # NOTE(review): presumably needed so the identity initialisation of
        # the MLP keeps every hidden feature -- confirm.
        assert self.mlp_size >= self.hidden_size, (
            "mlp_size must be greater than or equal to hidden_size"
        )
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76


class LlamaMLP(nn.Module):
    """Gated MLP: fused gate/up projection, gating product, down projection."""

    def __init__(self, config: LlamaConfig) -> None:
        """Create the two bias-free projections and initialise their weights.

        With ``config.tractable_init`` each half of the fused projection and
        the down projection are filled with an identity pattern; otherwise
        both weights are drawn from a Xavier-normal distribution using
        generators seeded with ``config.random_seed``.
        """
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            # The first ``mlp_size`` rows and the remaining rows are
            # initialised separately, each as its own identity block.
            nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
            nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
            nn.init.eye_(self.down_projection.weight.data)
        else:
            nn.init.xavier_normal_(
                self.gate_up_projection.weight.data,
                generator=torch.Generator().manual_seed(config.random_seed),
                gain=0.001,
            )
            nn.init.xavier_normal_(
                self.down_projection.weight.data,
                generator=torch.Generator().manual_seed(config.random_seed),
                gain=0.001,
            )

    def forward(self, x):
        """Project ``x`` up, multiply one half by the ReLU of the other, project down."""
        # for tractable_init and positive input, this is
        # essentially an elementwise-square
        x = self.gate_up_projection(x)
        half = x.size(1) // 2
        x = x[:, :half] * torch.nn.functional.relu(x[:, half:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):
    """Attention block built around the custom ``torch.ops.silly.attention`` op."""

    def __init__(self, config: LlamaConfig) -> None:
        """Create the fused QKV and output projections and initialise them.

        With ``config.tractable_init`` the query, key and value slices of the
        fused weight and the output weight are each filled with an identity
        pattern; otherwise both weights are drawn from a Xavier-normal
        distribution using generators seeded with ``config.random_seed``.
        """
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
            nn.init.eye_(
                self.qkv_projection.weight.data[
                    config.hidden_size : 2 * config.hidden_size
                ]
            )
            nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
            nn.init.eye_(self.output_projection.weight.data)
        else:
            nn.init.xavier_normal_(
                self.qkv_projection.weight.data,
                generator=torch.Generator().manual_seed(config.random_seed),
                gain=0.001,
            )
            nn.init.xavier_normal_(
                self.output_projection.weight.data,
                generator=torch.Generator().manual_seed(config.random_seed),
                gain=0.001,
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, add positions to Q and K, run the op, project out."""
        # for tractable_init, this is:
        # output = (hidden_states * 3 + positions * 2)
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)

        # positions gains a trailing axis so it broadcasts over the features
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)

        # the custom op writes its result into the pre-allocated buffer
        attn_output = torch.empty_like(q)
        torch.ops.silly.attention(q, k, v, attn_output)

        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: attention followed by the MLP, with a residual stream."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For tractable computation:
        - if residual is None, the outputs are:
            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        - if residual is not None, the outputs are:
            - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        """  # noqa
        # Fold the incoming residual (if any) into the hidden states; the
        # sum is what the attention output is added back onto.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1

        hidden_states = self.self_attention(
            positions=positions, hidden_states=hidden_states
        )

        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


200
@support_torch_compile
class LlamaModel(nn.Module):
    """Toy Llama: a constant-filled embedding followed by the decoder layers."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        config: LlamaConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Build the embedding table and ``config.num_layers`` decoder layers.

        ``vllm_config``, ``prefix`` and ``kwargs`` are not read in this body.
        NOTE(review): presumably they are required by the
        ``support_torch_compile`` decorator -- confirm.
        """
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)]
        )

        # this is the initial value of the hidden states
        self.embedding_tokens.weight.data.fill_(config.init_value)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and pass the result through every layer in order.

        Returns the hidden states produced by the last layer; the residual
        returned by that layer is discarded.
        """
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def tractable_computation(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    config: LlamaConfig,
    init_value: float = 1.0,
) -> torch.Tensor:
    """Compute in closed form the output expected from the tractable model.

    Args:
        input_ids: only its length, device and dtype are used.
        positions: one position per token; broadcast over the hidden axis.
        config: supplies ``hidden_size`` and ``num_layers``.
        init_value: value every initial hidden state starts from.

    Returns:
        A tensor with one row per token and ``config.hidden_size`` columns.
    """
    # Multiplying by the Python float ``init_value`` promotes an integer
    # ``ones`` tensor to the default floating-point dtype.
    hidden_states = (
        torch.ones(
            input_ids.size(0),
            config.hidden_size,
            device=input_ids.device,
            dtype=input_ids.dtype,
        )
        * init_value
    )

    # The positional contribution is identical for every layer.
    position_term = positions.unsqueeze(1) * 2

    # first layer
    residual = hidden_states * 4 + position_term + 3
    hidden_states = (residual + 1) ** 2

    # following layers
    for _ in range(config.num_layers - 1):
        hidden_states = hidden_states + residual
        residual = hidden_states * 4 + position_term + 3
        hidden_states = (residual + 1) ** 2

    return hidden_states


263
@torch.inference_mode
def run_model(
    llama_config, compile_config: CompilationConfig
) -> torch.Tensor | None:
    """Build the toy model under ``compile_config`` and run it on CUDA.

    The model is warmed up on the full batch, run once for two tokens and
    once for one token (simulated capture), and then run again for two
    tokens after the first two input ids were zeroed (simulated replay).

    Returns:
        The CPU output of the last run when ``llama_config.tractable_init``
        is false. When it is true the output is checked against
        ``tractable_computation`` instead and ``None`` is returned.
    """
    # Start with a fresh copy to make sure there's no cache dir sharing
    compile_config = deepcopy(compile_config)
    cudagraph_runtime_mode = compile_config.cudagraph_mode

    vllm_config = VllmConfig(
        compilation_config=compile_config, additional_config=llama_config
    )
    with set_current_vllm_config(vllm_config):
        model = (
            LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
            .eval()
            .cuda()
        )

    with set_forward_context({}, vllm_config=vllm_config):  # background context
        B = 16  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
        positions = torch.arange(B).cuda()

        def forward_first_tokens(num_tokens: int) -> torch.Tensor:
            """Run the model on the first ``num_tokens`` tokens of the batch."""
            with set_forward_context(
                {},
                vllm_config=vllm_config,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                batch_descriptor=BatchDescriptor(
                    num_tokens=num_tokens,
                ),
            ):
                return model(input_ids[:num_tokens], positions[:num_tokens])

        # warmup for the model with cudagraph_mode NONE
        model(input_ids, positions)

        # simulate cudagraphs capturing
        forward_first_tokens(2)
        forward_first_tokens(1)

        input_ids[:2].zero_()
        # simulate cudagraphs replay
        output = forward_first_tokens(2).cpu()

        if llama_config.tractable_init:
            expected_output = tractable_computation(
                input_ids[:2], positions[:2], llama_config
            ).cpu()

            assert torch.allclose(output, expected_output)
            return None

        return output
329
330


331
332
333
334
335
336
337
338
@pytest.mark.parametrize(
    "backend, use_inductor_graph_partition",
    [
        ("eager", False),  # No inductor
        ("inductor", False),  # Inductor, Dynamo partition
        ("inductor", True),  # Inductor, Inductor partition
    ],
)
@create_new_process_for_each_test("spawn")
def test_toy_llama(
    backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
):
    """Compare the toy model's output with and without piecewise compilation.

    The model is run uncompiled, compiled without splitting ops, and compiled
    with ``silly::attention`` as a splitting op. Each run is wrapped in the
    expected compilation-counter deltas, the tractable configuration is
    checked against its closed form in every setup, and the outputs of the
    randomly initialised model must match across the three setups.
    """
    # Disable the vLLM compile cache so that the number of Inductor
    # compilations can be tracked properly.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition only supported in torch>=2.9")

    # compare output with and without piecewise compilation

    llama_config = LlamaConfig(
        hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
    )

    tractable_config = LlamaConfig(
        hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
    )

    compile_config_no_compile = CompilationConfig(
        mode=CompilationMode.NONE,
        cudagraph_mode=CUDAGraphMode.NONE,
        backend="eager",
    )

    compile_config_no_split = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        use_inductor_graph_partition=use_inductor_graph_partition,
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
        backend=backend,
        cudagraph_capture_sizes=[1, 2],
    )

    compile_config_split = deepcopy(compile_config_no_split)
    compile_config_split.splitting_ops = ["silly::attention"]

    outputs = []
    with compilation_counter.expect(
        num_graphs_seen=0,
        num_piecewise_graphs_seen=0,
        num_piecewise_capturable_graphs_seen=0,
        num_backend_compilations=0,
        num_cudagraph_captured=0,
    ):
        outputs.append(run_model(llama_config, compile_config_no_compile))

    run_model(tractable_config, compile_config_no_compile)

    if backend == "inductor":
        kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
    else:
        kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

    with compilation_counter.expect(
        num_graphs_seen=1,  # one graph for the model
        num_piecewise_graphs_seen=1,
        num_piecewise_capturable_graphs_seen=1,
        num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
        num_cudagraph_captured=2,
        **kwargs,
    ):
        outputs.append(run_model(llama_config, compile_config_no_split))

    run_model(tractable_config, compile_config_no_split)

    if use_inductor_graph_partition:
        num_piecewise_fx = 1
        num_piecewise_capturable_fx = 1
    else:
        num_piecewise_fx = 2 * llama_config.num_layers + 1
        num_piecewise_capturable_fx = 1 + llama_config.num_layers

    with compilation_counter.expect(
        num_graphs_seen=1,  # one graph for the model
        num_piecewise_graphs_seen=num_piecewise_fx,
        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
        num_backend_compilations=num_piecewise_capturable_fx,
        # num_cudagraph_sizes * num_partitions
        num_cudagraph_captured=2 * (1 + llama_config.num_layers),
    ):
        outputs.append(run_model(llama_config, compile_config_split))
    run_model(tractable_config, compile_config_split)

    # every compiled variant must reproduce the uncompiled reference output
    reference = outputs[0]
    for candidate in outputs[1:]:
        assert torch.allclose(reference, candidate)


@torch.inference_mode
def benchmark():
    """Time the toy model per batch size and print the results as a table.

    Three timings are collected with ``triton.testing.do_bench`` for every
    size in ``cudagraph_sizes``: calling the model directly, replaying a
    manually captured ``torch.cuda.CUDAGraph``, and calling the model built
    with ``silly::attention`` as a splitting op.
    """
    from triton.testing import do_bench

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(
        hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
    )

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    # NOTE: this assignment replaces the configuration above, so only the
    # tiny model is benchmarked as written.
    llama_config = LlamaConfig(
        hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
    )

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}
    piecewise_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    for piecewise in [False, True]:
        if piecewise:
            compilation_config = CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                splitting_ops=["silly::attention"],
                cudagraph_capture_sizes=cudagraph_sizes,
            )
        else:
            compilation_config = CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                cudagraph_capture_sizes=cudagraph_sizes,
            )

        vllm_config = VllmConfig(compilation_config=compilation_config)
        with set_current_vllm_config(vllm_config):
            model = (
                LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
                .eval()
                .cuda()
                .to(torch.bfloat16)
            )

        B = 256  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
        positions = torch.arange(B).cuda().to(torch.bfloat16)

        graphs = {}

        model(input_ids, positions)
        # sizes are visited from largest to smallest
        for b in cudagraph_sizes[::-1]:
            if not piecewise:
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph, pool=pool):
                    output = model(input_ids[:b], positions[:b])
                graphs[b] = (graph, output)
            else:
                output = model(input_ids[:b], positions[:b])
                graphs[b] = (model, output)
        for b in cudagraph_sizes:
            if piecewise:
                # noqa is for `Function definition does not bind loop variable`
                # it will be problematic if we save the created lambda function
                # and use it later, because it will look up the name `b` in the
                # enclosing scope, and the value of `b` will always be 256.
                # it is fine here, because we only use the lambda function once.
                runtime = do_bench(
                    lambda: graphs[b][0](  # noqa
                        input_ids[:b],  # noqa
                        positions[:b],  # noqa
                    )
                )
                piecewise_cudagraph_time[b] = runtime
            else:
                runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
                eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b]))  # noqa
                full_cudagraph_time[b] = runtime
                eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
    for b in cudagraph_sizes:
        print(
            f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
            f"\t{piecewise_cudagraph_time[b]:.3f}"
        )
516
517
518


if __name__ == "__main__":
    # Protect against subprocess reimport when using spawn_new_process_for_each_test
    import os

    # Run the benchmark only when this module is executed directly and is
    # not being re-imported inside a test subprocess.
    if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
        benchmark()