# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.

This is a tractable model: the weights and computation are specially designed
when the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
from dataclasses import dataclass
12
from typing import Any, Optional
13

14
import pytest
15
16
17
18
19
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
20
21
22
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                         VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
23

24
25
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention  # noqa: F401
26
27


28
29
30
31
32
33
@dataclass
class LlamaConfig:
    """Hyper-parameters of the toy Llama model used in this test."""
    # width of the hidden states and of the embedding vectors
    hidden_size: int = 128
    # inner width of the MLP (the gate/up projection emits twice this)
    mlp_size: int = 256
    # number of rows in the embedding table
    vocab_size: int = 128
    # number of decoder layers
    num_layers: int = 2
    # constant every embedding weight is filled with
    init_value: float = 1.0
    # identity-initialise the projections instead of using random weights
    tractable_init: bool = False
    # seed for the random (non-tractable) weight initialisation
    random_seed: int = 0

    def compute_hash(self) -> str:
        """Return an MD5 hex digest of all fields except ``random_seed``.

        The ``(name, value)`` pairs are sorted by name before hashing, so the
        digest does not depend on field declaration order.
        """
        import hashlib

        factors: list[Any] = sorted(
            (name, value) for name, value in self.__dict__.items()
            if name != "random_seed")
        return hashlib.md5(str(factors).encode(),
                           usedforsecurity=False).hexdigest()

    def __post_init__(self):
        # NOTE(review): presumably required so the identity-initialised MLP
        # projections do not truncate the hidden state - confirm.
        assert self.mlp_size >= self.hidden_size
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67


class LlamaMLP(nn.Module):
    """Gated MLP: a fused gate/up projection followed by a down projection."""

    def __init__(self, config: "LlamaConfig") -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            # Both halves of the fused projection and the down projection
            # become (rectangular) identity matrices.
            gate_up_weight = self.gate_up_projection.weight.data
            nn.init.eye_(gate_up_weight[:config.mlp_size])
            nn.init.eye_(gate_up_weight[config.mlp_size:])
            nn.init.eye_(self.down_projection.weight.data)
        else:
            # Each weight gets its own generator seeded with the same value,
            # so the initialisation is reproducible across model instances.
            for projection in (self.gate_up_projection,
                               self.down_projection):
                nn.init.xavier_normal_(projection.weight.data,
                                       generator=torch.Generator().manual_seed(
                                           config.random_seed),
                                       gain=0.001)

    def forward(self, x):
        """Return ``down(first_half * relu(second_half))`` of the fused
        projection of ``x`` (split along dim 1)."""
        # for tractable_init and positive input, this is
        # essentially an elementwise-square
        x = self.gate_up_projection(x)
        half = x.size(1) // 2
        x = x[:, :half] * torch.nn.functional.relu(x[:, half:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):
    """Toy attention block built on the custom ``silly.attention`` op."""

    def __init__(self, config: "LlamaConfig") -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            # q, k and v slices of the fused projection, as well as the
            # output projection, are all identity matrices.
            qkv_weight = self.qkv_projection.weight.data
            size = config.hidden_size
            nn.init.eye_(qkv_weight[:size])
            nn.init.eye_(qkv_weight[size:2 * size])
            nn.init.eye_(qkv_weight[2 * size:])
            nn.init.eye_(self.output_projection.weight.data)
        else:
            # Each weight gets its own generator seeded with the same value,
            # so the initialisation is reproducible across model instances.
            for projection in (self.qkv_projection, self.output_projection):
                nn.init.xavier_normal_(projection.weight.data,
                                       generator=torch.Generator().manual_seed(
                                           config.random_seed),
                                       gain=0.001)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, add ``positions`` to q and k, run the custom
        attention op and apply the output projection."""
        # for tractable_init, this is:
        # output = (hidden_states * 3 + positions * 2)
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)

        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)

        # the custom op writes its result into the pre-allocated buffer
        attn_output = torch.empty_like(q)
        torch.ops.silly.attention(q, k, v, attn_output)

        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: attention then MLP, each preceded by a residual
    update and a ``+ 1`` step."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For tractable computation:
        - if residual is None, the outputs are:
            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        - if residual is not None, the outputs are:
            - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        """ # noqa
        # Fold in the incoming residual (if any); the result is the residual
        # that gets added back after attention.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1

        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)

        # Second residual update; this residual is what the layer returns.
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


187
@support_torch_compile
class LlamaModel(nn.Module):
    """Toy Llama model: a constant-filled embedding followed by
    ``config.num_layers`` decoder layers."""

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 config: LlamaConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        # NOTE(review): `vllm_config`, `prefix` and `kwargs` are not used in
        # this body; presumably they are consumed by the
        # `support_torch_compile` decorator - confirm.
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])

        # this is the initial value of the hidden states
        self.embedding_tokens.weight.data.fill_(config.init_value)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run all decoder layers, threading the
        residual through; return the final hidden states."""
        hidden_states = self.embedding_tokens(input_ids)
        # the first layer receives no residual
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def tractable_computation(input_ids: torch.Tensor,
                          positions: torch.Tensor,
                          config: "LlamaConfig",
                          init_value: Optional[float] = None) -> torch.Tensor:
    """Closed-form reference output of the model under ``tractable_init``.

    Args:
        input_ids: only its length, device and dtype are used.
        positions: one position per token; broadcast across the hidden
            dimension.
        config: supplies ``hidden_size``, ``num_layers`` and the default
            ``init_value``.
        init_value: initial value of every hidden-state element. Defaults to
            ``config.init_value``, the constant the model's embedding table
            is filled with.

    Returns:
        Tensor of shape ``(len(input_ids), config.hidden_size)``.
    """
    if init_value is None:
        init_value = config.init_value

    hidden_states = torch.ones(input_ids.size(0),
                               config.hidden_size,
                               device=input_ids.device,
                               dtype=input_ids.dtype) * init_value
    position_term = positions.unsqueeze(1) * 2

    # first layer
    residual = hidden_states * 4 + position_term + 3
    hidden_states = (residual + 1)**2

    # following layers
    for _ in range(config.num_layers - 1):
        hidden_states = hidden_states + residual
        residual = hidden_states * 4 + position_term + 3
        hidden_states = (residual + 1)**2

    return hidden_states


241
242
243
@torch.inference_mode
def run_model(llama_config: LlamaConfig,
              use_compile: bool,
              use_inductor: bool,
              split_attn: bool = False) -> Optional[torch.Tensor]:
    """Build the toy model and run warmup, capture and replay passes.

    Args:
        llama_config: configuration of the toy model.
        use_compile: use piecewise compilation with cudagraphs when True,
            otherwise no compilation at all.
        use_inductor: forwarded to ``CompilationConfig`` (only when
            ``use_compile`` is True).
        split_attn: when compiling, split the graph at ``silly.attention``.

    Returns:
        The replay output on the CPU for a non-tractable config. For a
        tractable config the output is asserted against
        ``tractable_computation`` and ``None`` is returned.
    """
    if use_compile:
        compilation_config = CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            use_inductor=use_inductor,
            cudagraph_capture_sizes=[1, 2],
        )
        if split_attn:
            compilation_config.splitting_ops = ["silly.attention"]
        cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
    else:
        compilation_config = CompilationConfig(
            level=CompilationLevel.NO_COMPILATION, )
        cudagraph_runtime_mode = CUDAGraphMode.NONE

    vllm_config = VllmConfig(compilation_config=compilation_config,
                             additional_config=llama_config)
    with set_current_vllm_config(vllm_config):
        model = LlamaModel(config=llama_config,
                           vllm_config=vllm_config,
                           prefix="").eval().cuda()

    with set_forward_context({},
                             vllm_config=vllm_config):  # background context
        B = 16  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
        positions = torch.arange(B).cuda()

        def run_with_runtime_mode(num_tokens: int) -> torch.Tensor:
            # Run the first `num_tokens` tokens under a forward context that
            # carries the cudagraph runtime mode and batch descriptor.
            with set_forward_context(
                {},
                    vllm_config=vllm_config,
                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                    batch_descriptor=BatchDescriptor(
                        num_tokens=num_tokens, )):
                return model(input_ids[:num_tokens], positions[:num_tokens])

        # warmup for the model with cudagraph_mode NONE
        model(input_ids, positions)

        # simulate cudagraphs capturing
        run_with_runtime_mode(2)
        run_with_runtime_mode(1)

        # overwrite the token ids in place so the replay sees inputs that
        # differ from the ones used while capturing
        input_ids[:2].zero_()

        # simulate cudagraphs replay
        output = run_with_runtime_mode(2)

        output = output.cpu()

        if llama_config.tractable_init:
            expected_output = tractable_computation(
                input_ids[:2],
                positions[:2],
                llama_config,
                init_value=llama_config.init_value).cpu()

            assert torch.allclose(output, expected_output)
            return None
        return output
311
312


313
314
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
    """Check that eager, compiled and attention-split compiled runs agree,
    and that the compilation counters match the expected graph counts."""
    # compare output with and without piecewise compilation

    llama_config = LlamaConfig(hidden_size=128,
                               mlp_size=256,
                               vocab_size=128,
                               num_layers=12)

    tractable_config = LlamaConfig(hidden_size=128,
                                   mlp_size=256,
                                   vocab_size=128,
                                   num_layers=2,
                                   tractable_init=True)

    outputs = []

    # 1) no compilation: nothing should be counted
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ):
        outputs.append(
            run_model(llama_config, use_inductor=False, use_compile=False))
    run_model(tractable_config, use_inductor=False, use_compile=False)

    if use_inductor:
        kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
    else:
        kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

    # 2) compilation without splitting: the model is a single piece
    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=1,
            num_piecewise_capturable_graphs_seen=1,
            num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
            num_cudagraph_captured=2,
            **kwargs,
    ):
        outputs.append(
            run_model(llama_config,
                      use_inductor=use_inductor,
                      use_compile=True))
    run_model(tractable_config, use_inductor=use_inductor, use_compile=True)

    # 3) compilation split at the attention op
    num_layers = llama_config.num_layers
    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=2 * num_layers + 1,
            num_piecewise_capturable_graphs_seen=1 + num_layers,
            # num_piecewise_capturable_graphs_seen
            num_backend_compilations=1 + num_layers,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
            num_cudagraph_captured=2 * (1 + num_layers),
    ):
        outputs.append(
            run_model(llama_config,
                      use_inductor=use_inductor,
                      use_compile=True,
                      split_attn=True))
    run_model(tractable_config,
              use_inductor=use_inductor,
              use_compile=True,
              split_attn=True)

    # every compiled variant must match the eager reference
    for output in outputs[1:]:
        assert torch.allclose(outputs[0], output)


@torch.inference_mode
def benchmark():
    """Time eager, full-cudagraph and piecewise-cudagraph execution of the
    toy model for a range of batch sizes and print the results as a table."""
    from triton.testing import do_bench

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    # NOTE: this assignment overrides the larger config above; remove it to
    # benchmark the larger model instead.
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}
    piecewise_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    for piecewise in [False, True]:
        if piecewise:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                use_cudagraph=True,
                splitting_ops=["silly.attention"],
                cudagraph_capture_sizes=cudagraph_sizes,
            )
        else:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                cudagraph_capture_sizes=cudagraph_sizes,
            )

        vllm_config = VllmConfig(compilation_config=compilation_config)
        with set_current_vllm_config(vllm_config):
            model = LlamaModel(config=llama_config,
                               vllm_config=vllm_config,
                               prefix="").eval().cuda().to(torch.bfloat16)

        B = 256  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
        positions = torch.arange(B).cuda().to(torch.bfloat16)

        # batch size -> (callable or captured graph, last output)
        graphs = {}

        model(input_ids, positions)
        # iterate from the largest batch size down to the smallest
        for b in cudagraph_sizes[::-1]:
            if not piecewise:
                # capture one full cudagraph per batch size by hand
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph, pool=pool):
                    output = model(input_ids[:b], positions[:b])
                graphs[b] = (graph, output)
            else:
                output = model(input_ids[:b], positions[:b])
                graphs[b] = (model, output)

        for b in cudagraph_sizes:
            if piecewise:
                # noqa is for `Function definition does not bind loop variable`
                # it will be problematic if we save the created lambda function
                # and use it later, because it will look up the name `b` in the
                # enclosing scope, and the value of `b` will always be 256.
                # it is fine here, because we only use the lambda function once.
                runtime = do_bench(lambda: graphs[b][0]  # noqa
                                   (input_ids[:b], positions[:b]))  # noqa
                piecewise_cudagraph_time[b] = runtime
            else:
                runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
                eager_runtime = do_bench(
                    lambda: model(input_ids[:b], positions[:b]))  # noqa
                full_cudagraph_time[b] = runtime
                eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
    for b in cudagraph_sizes:
        print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
              f"\t{piecewise_cudagraph_time[b]:.3f}")
469
470
471
472


# Running this file as a script runs the benchmark instead of the test.
if __name__ == "__main__":
    benchmark()