# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.
6
7
8
9

This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
10
"""
11

12
from dataclasses import dataclass
13
from typing import Any, Optional
14

15
import pytest
16
17
18
19
20
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
21
22
23
24
25
26
27
from vllm.config import (
    CompilationConfig,
    CompilationLevel,
    CUDAGraphMode,
    VllmConfig,
    set_current_vllm_config,
)
28
from vllm.forward_context import BatchDescriptor, set_forward_context
29

30
31
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention  # noqa: F401
32
33


34
35
36
37
38
39
@dataclass
class LlamaConfig:
    """Hyper-parameters of the toy Llama model exercised by this test."""

    hidden_size: int = 128  # feature width of the hidden states
    mlp_size: int = 256  # inner width of the MLP; must be >= hidden_size
    vocab_size: int = 128  # number of rows in the embedding table
    num_layers: int = 2  # number of stacked decoder layers
    init_value: float = 1.0  # value every embedding weight is filled with
    tractable_init: bool = False  # identity weights instead of random ones
    random_seed: int = 0  # seed of the generators used by the random init

    def compute_hash(self) -> str:
        """Return an MD5 hex digest of all fields except ``random_seed``.

        The digest is taken over the string form of the ``(name, value)``
        pairs sorted by name, so two configs that differ only in
        ``random_seed`` produce the same hash.
        """
        import hashlib

        factors: list[Any] = sorted(
            (name, value)
            for name, value in self.__dict__.items()
            if name != "random_seed"
        )
        return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()

    def __post_init__(self):
        # Runs right after the generated __init__; rejects an MLP narrower
        # than the hidden states.
        assert self.mlp_size >= self.hidden_size, (
            f"mlp_size ({self.mlp_size}) must be >= "
            f"hidden_size ({self.hidden_size})"
        )
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72


class LlamaMLP(nn.Module):
    """Gated MLP: fused gate/up projection, ReLU gating, down projection.

    With ``config.tractable_init`` the weights are identity blocks, otherwise
    they are drawn with ``xavier_normal_`` from generators seeded with
    ``config.random_seed``.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # A single linear layer produces both halves (2 * mlp_size outputs).
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            # Each half of the fused weight is set to a (rectangular)
            # identity, so both halves of the projection reproduce the input.
            nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
            nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
            nn.init.eye_(self.down_projection.weight.data)
        else:
            # A fresh generator with the same seed is created for each weight,
            # which makes the initialization reproducible across instances.
            nn.init.xavier_normal_(
                self.gate_up_projection.weight.data,
                generator=torch.Generator().manual_seed(config.random_seed),
                gain=0.001,
            )
            nn.init.xavier_normal_(
                self.down_projection.weight.data,
                generator=torch.Generator().manual_seed(config.random_seed),
                gain=0.001,
            )

    def forward(self, x):
        """Apply the MLP to ``x``; the split below indexes dim 1 of a 2-D input."""
        # for tractable_init and positive input, this is
        # essentially an elementwise-square
        x = self.gate_up_projection(x)
        half = x.size(1) // 2
        # first half is multiplied by the ReLU of the second half
        x = x[:, :half] * torch.nn.functional.relu(x[:, half:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):
    """Toy attention block: fused QKV projection, the ``silly.attention``
    custom op, and an output projection.

    With ``config.tractable_init`` all projection weights are identity blocks,
    otherwise they are drawn with ``xavier_normal_`` from generators seeded
    with ``config.random_seed``.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # One linear layer yields q, k and v concatenated along the last dim.
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            # Set each of the three hidden_size x hidden_size blocks (q, k, v)
            # of the fused weight to the identity.
            nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
            nn.init.eye_(
                self.qkv_projection.weight.data[
                    config.hidden_size : 2 * config.hidden_size
                ]
            )
            nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
            nn.init.eye_(self.output_projection.weight.data)
        else:
            # A fresh generator with the same seed is created for each weight.
            nn.init.xavier_normal_(
                self.qkv_projection.weight.data,
                generator=torch.Generator().manual_seed(config.random_seed),
                gain=0.001,
            )
            nn.init.xavier_normal_(
                self.output_projection.weight.data,
                generator=torch.Generator().manual_seed(config.random_seed),
                gain=0.001,
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, add positions to q and k, run the custom op and
        project the result back to ``hidden_size``."""
        # for tractable_init, this is:
        # output = (hidden_states * 3 + positions * 2)
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)

        # positions is broadcast over the feature dimension
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)

        # The op receives a pre-allocated output buffer as its last argument.
        # NOTE(review): the op is registered by the `silly_attention` import;
        # judging by the formula above it presumably writes q + k + v into
        # `attn_output` -- confirm against its definition.
        attn_output = torch.empty_like(q)
        torch.ops.silly.attention(q, k, v, attn_output)

        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: attention then MLP, each preceded by a residual
    update and a ``+ 1`` step."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer; later layers receive the
        residual returned by the previous one.

        For tractable computation:
        - if residual is None, the outputs are:
            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        - if residual is not None, the outputs are:
            - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        """  # noqa
        # Fold the incoming residual (if any) into the hidden states; the sum
        # becomes the residual carried around the attention block.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1

        hidden_states = self.self_attention(
            positions=positions, hidden_states=hidden_states
        )

        # Same pattern around the MLP: add residual, remember it, add one.
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


196
@support_torch_compile
class LlamaModel(nn.Module):
    """Toy Llama: an embedding table followed by ``config.num_layers``
    decoder layers.

    The class is wrapped by ``support_torch_compile``.
    NOTE(review): ``vllm_config``, ``prefix`` and ``**kwargs`` are not read in
    this body; presumably the decorator consumes them -- confirm before
    changing the constructor signature or the ``forward`` annotations.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        config: LlamaConfig,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)]
        )

        # this is the initial value of the hidden states
        # (every embedding row is constant, so the token id does not change
        # the embedding output)
        self.embedding_tokens.weight.data.fill_(config.init_value)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run every decoder layer in order.

        Returns the hidden states produced by the last layer; the final
        residual is discarded.
        """
        hidden_states = self.embedding_tokens(input_ids)
        # The first layer receives no residual.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
def tractable_computation(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    config: LlamaConfig,
    init_value: Optional[float] = None,
) -> torch.Tensor:
    """Closed-form reference for the model output under ``tractable_init``.

    Mirrors the formulas documented on ``LlamaDecoderLayer.forward``.

    Args:
        input_ids: only its length, device and dtype are used; the values are
            ignored because every embedding row holds the same constant.
        positions: one position per token, broadcast over the feature dim.
        config: supplies ``hidden_size``, ``num_layers`` and the default
            initial value.
        init_value: initial value of the hidden states. Defaults to
            ``config.init_value`` so the reference agrees with the value
            ``LlamaModel`` fills its embedding with.

    Returns:
        Tensor of shape ``(len(input_ids), config.hidden_size)``.
    """
    if init_value is None:
        init_value = config.init_value

    hidden_states = (
        torch.ones(
            input_ids.size(0),
            config.hidden_size,
            device=input_ids.device,
            dtype=input_ids.dtype,
        )
        * init_value
    )

    # The position term is the same in every layer; compute it once.
    position_term = positions.unsqueeze(1) * 2

    # first layer
    residual = hidden_states * 4 + position_term + 3
    hidden_states = (residual + 1) ** 2

    # following layers
    for _ in range(config.num_layers - 1):
        hidden_states = hidden_states + residual
        residual = hidden_states * 4 + position_term + 3
        hidden_states = (residual + 1) ** 2

    return hidden_states


259
@torch.inference_mode
def run_model(
    llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
) -> torch.Tensor:
    """Build the toy model, run it on CUDA and return the final output on CPU.

    The sequence is: one warm-up forward on the full batch, forwards for
    batch sizes 2 and 1 under the selected cudagraph runtime mode (the
    "capture" passes), an in-place zeroing of the first two input ids, and a
    last forward of batch size 2 (the "replay" pass) whose output is returned.

    Args:
        llama_config: the ``LlamaConfig`` describing the model; also passed
            to ``VllmConfig`` as ``additional_config``.
        use_compile: use ``CompilationLevel.PIECEWISE`` with cudagraphs when
            True, ``CompilationLevel.NO_COMPILATION`` otherwise.
        use_inductor: forwarded to ``CompilationConfig`` when compiling.
        split_attn: when compiling, set ``splitting_ops`` to
            ``["silly.attention"]``.

    Returns:
        The output of the replay pass, moved to CPU. When
        ``llama_config.tractable_init`` is set the output is additionally
        asserted to match ``tractable_computation``.
    """
    if use_compile:
        compilation_config = CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            use_inductor=use_inductor,
            cudagraph_capture_sizes=[1, 2],
        )
        if split_attn:
            compilation_config.splitting_ops = ["silly.attention"]
        cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
    else:
        compilation_config = CompilationConfig(
            level=CompilationLevel.NO_COMPILATION,
        )
        cudagraph_runtime_mode = CUDAGraphMode.NONE

    vllm_config = VllmConfig(
        compilation_config=compilation_config, additional_config=llama_config
    )
    with set_current_vllm_config(vllm_config):
        model = (
            LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
            .eval()
            .cuda()
        )

    with set_forward_context({}, vllm_config=vllm_config):  # background context
        B = 16  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
        positions = torch.arange(B).cuda()

        def forward_first(num_tokens: int) -> torch.Tensor:
            # Run the model on the first `num_tokens` tokens inside a forward
            # context carrying the runtime mode and a matching descriptor.
            with set_forward_context(
                {},
                vllm_config=vllm_config,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                batch_descriptor=BatchDescriptor(
                    num_tokens=num_tokens,
                ),
            ):
                return model(input_ids[:num_tokens], positions[:num_tokens])

        # warmup for the model with cudagraph_mode NONE
        model(input_ids, positions)

        # simulate cudagraphs capturing
        forward_first(2)
        forward_first(1)

        # Change the inputs in place between capture and replay.
        input_ids[:2].zero_()
        # simulate cudagraphs replay
        output = forward_first(2)

        output = output.cpu()

        if llama_config.tractable_init:
            expected_output = tractable_computation(
                input_ids[:2], positions[:2], llama_config
            ).cpu()

            assert torch.allclose(output, expected_output)

        # Return the tensor on every path so the result matches the
        # annotated return type.
        return output
339
340


341
342
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
    """Check outputs and compilation counters for three execution setups:
    no compilation, piecewise compilation without splitting, and piecewise
    compilation split at ``silly.attention``."""
    # compare output with and without piecewise compilation

    llama_config = LlamaConfig(
        hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
    )

    tractable_config = LlamaConfig(
        hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
    )

    outputs = []

    # 1) No compilation: nothing may be counted.
    with compilation_counter.expect(
        num_graphs_seen=0,
        num_piecewise_graphs_seen=0,
        num_piecewise_capturable_graphs_seen=0,
        num_backend_compilations=0,
        num_cudagraph_captured=0,
    ):
        outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
    run_model(tractable_config, use_inductor=False, use_compile=False)

    if use_inductor:
        kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
    else:
        kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

    # 2) Piecewise compilation without splitting ops: a single piece.
    with compilation_counter.expect(
        # One graph for the model
        num_graphs_seen=1,
        num_piecewise_graphs_seen=1,
        num_piecewise_capturable_graphs_seen=1,
        # num_piecewise_capturable_graphs_seen
        num_backend_compilations=1,
        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
        num_cudagraph_captured=2,
        **kwargs,
    ):
        outputs.append(
            run_model(llama_config, use_inductor=use_inductor, use_compile=True)
        )
    run_model(tractable_config, use_inductor=use_inductor, use_compile=True)

    # 3) Piecewise compilation split at the attention op.
    num_layers = llama_config.num_layers
    num_capturable = 1 + num_layers
    with compilation_counter.expect(
        num_graphs_seen=1,  # one graph for the model
        num_piecewise_graphs_seen=2 * num_layers + 1,
        num_piecewise_capturable_graphs_seen=num_capturable,
        # num_piecewise_capturable_graphs_seen
        num_backend_compilations=num_capturable,
        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
        num_cudagraph_captured=2 * num_capturable,
    ):
        outputs.append(
            run_model(
                llama_config,
                use_inductor=use_inductor,
                use_compile=True,
                split_attn=True,
            )
        )
    run_model(
        tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
    )

    # Every compiled variant must match the uncompiled reference output.
    reference = outputs[0]
    for other in outputs[1:]:
        assert torch.allclose(reference, other)


@torch.inference_mode
def benchmark(*, use_large_model: bool = False):
    """Time the toy model per batch size and print a tab-separated table.

    Three timings are collected with ``triton.testing.do_bench`` for every
    size in ``cudagraph_sizes``: a direct model call (reported as
    "eager mode"), replay of a ``torch.cuda.CUDAGraph`` captured around the
    whole model ("full cudagraph"), and a call of the model built with
    ``splitting_ops=["silly.attention"]`` ("piecewise cudagraph").

    Args:
        use_large_model: benchmark a configuration similar to llama 3.1-8B
            instead of the default tiny model.
    """
    from triton.testing import do_bench

    if use_large_model:
        # similar to llama 3.1-8B
        llama_config = LlamaConfig(
            hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
        )
    else:
        # a tiny model to measure the overhead
        # of piecewise cudagraph
        llama_config = LlamaConfig(
            hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
        )

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}
    piecewise_cudagraph_time = {}

    # One memory pool is shared by all full-graph captures.
    pool = torch.cuda.graph_pool_handle()

    for piecewise in [False, True]:
        if piecewise:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                use_cudagraph=True,
                splitting_ops=["silly.attention"],
                cudagraph_capture_sizes=cudagraph_sizes,
            )
        else:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                cudagraph_capture_sizes=cudagraph_sizes,
            )

        vllm_config = VllmConfig(compilation_config=compilation_config)
        with set_current_vllm_config(vllm_config):
            model = (
                LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
                .eval()
                .cuda()
                .to(torch.bfloat16)
            )

        B = 256  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
        positions = torch.arange(B).cuda().to(torch.bfloat16)

        # batch size -> (callable or CUDAGraph, output of the setup run)
        graphs = {}

        # One run on the full batch before the per-size runs.
        model(input_ids, positions)
        # Sizes are visited from largest to smallest.
        for b in cudagraph_sizes[::-1]:
            if not piecewise:
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph, pool=pool):
                    output = model(input_ids[:b], positions[:b])
                graphs[b] = (graph, output)
            else:
                output = model(input_ids[:b], positions[:b])
                graphs[b] = (model, output)

        for b in cudagraph_sizes:
            if piecewise:
                # noqa is for `Function definition does not bind loop variable`
                # it will be problematic if we save the created lambda function
                # and use it later, because it will look up the name `b` in the
                # enclosing scope, and the value of `b` will always be 256.
                # it is fine here, because we only use the lambda function once.
                runtime = do_bench(
                    lambda: graphs[b][0](  # noqa
                        input_ids[:b],  # noqa
                        positions[:b],  # noqa
                    )
                )
                piecewise_cudagraph_time[b] = runtime
            else:
                runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
                eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b]))  # noqa
                full_cudagraph_time[b] = runtime
                eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
    for b in cudagraph_sizes:
        print(
            f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
            f"\t{piecewise_cudagraph_time[b]:.3f}"
        )
502
503
504
505


# Running this file as a script (rather than through pytest) executes the
# benchmark.
if __name__ == "__main__":
    benchmark()