# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.

This is a tractable model: if the config `tractable_init` is set to True, the
weights and computation are specially designed so that the output can be
computed in closed form. Otherwise, the weights are initialized randomly with
a fixed seed.
"""
from dataclasses import dataclass
12
from typing import Any, Optional
13

14
import pytest
15
16
import torch
from torch import nn
17
from torch.library import Library
18
19
20

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
21
22
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)
23
from vllm.forward_context import set_forward_context
24
25
26
27
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa
28
29
30
31
32
33
34
35
36


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    """Toy "attention": store ``q + k + v`` into ``out`` in place.

    ``out`` is overwritten with a copy of ``q`` and then accumulates ``k``
    and ``v``; the inputs are left untouched and nothing is returned.
    """
    out.copy_(q).add_(k).add_(v)


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    """Fake implementation of ``silly_attention``.

    It is registered as the op's ``fake_impl``. The real op only writes into
    the caller-provided ``out`` and returns nothing, so there is no output to
    describe here and the function simply returns.
    """
    return


# Expose `silly_attention` as `torch.ops.silly.attention`. The op writes its
# result into `out`, so that argument is declared as mutated, and
# `silly_attention_fake` is supplied as its fake implementation.
direct_register_custom_op(
    op_name="attention",
    target_lib=silly_lib,
    op_func=silly_attention,
    fake_impl=silly_attention_fake,
    mutates_args=["out"],
)


@dataclass
class LlamaConfig:
    """Configuration of the toy Llama-like model used by this test.

    Attributes:
        hidden_size: width of the embeddings / hidden states.
        mlp_size: width of each half of the MLP's fused projection (the
            projection outputs ``2 * mlp_size`` features); must be at least
            ``hidden_size``.
        vocab_size: number of rows in the embedding table.
        num_layers: number of decoder layers.
        init_value: constant that every embedding weight is filled with.
        tractable_init: if True, projections are identity-initialized so the
            output can be predicted in closed form; otherwise they are
            Xavier-normal initialized from ``random_seed``.
        random_seed: seed for the random initialization; excluded from
            ``compute_hash``.
    """
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2
    init_value: float = 1.0
    tractable_init: bool = False
    random_seed: int = 0

    def compute_hash(self) -> str:
        """Return the MD5 hex digest of all fields except ``random_seed``.

        Fields are sorted by name before hashing, so the digest does not
        depend on attribute order.
        """
        import hashlib

        # random_seed only feeds the weight-initialization generators, so it
        # is (presumably deliberately) left out of the hash.
        factors: list[Any] = sorted(
            (k, v) for k, v in self.__dict__.items() if k != "random_seed")
        return hashlib.md5(str(factors).encode(),
                           usedforsecurity=False).hexdigest()

    def __post_init__(self):
        # With tractable_init the MLP weights are identity blocks; presumably
        # mlp_size >= hidden_size is required so no hidden dimension is lost.
        assert self.mlp_size >= self.hidden_size, (
            f"mlp_size ({self.mlp_size}) must be >= "
            f"hidden_size ({self.hidden_size})")


class LlamaMLP(nn.Module):
    """Bias-free gated MLP: ``down(a * relu(b))``.

    ``a`` and ``b`` are the two halves of one fused projection.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Fused projection: forward() splits its 2 * mlp_size outputs in half.
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            # Identity blocks: both halves copy the input (zero-padded up to
            # mlp_size) and the down projection keeps the first hidden_size
            # channels, so the MLP reduces to x * relu(x).
            nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
            nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
            nn.init.eye_(self.down_projection.weight.data)
        else:
            # Xavier-normal init with gain 0.001, seeded from
            # config.random_seed so every model built from the same config
            # (eager or compiled) gets identical weights.
            nn.init.xavier_normal_(self.gate_up_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)
            nn.init.xavier_normal_(self.down_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)

    def forward(self, x):
        """Apply the MLP.

        Expects ``x`` to be 2-D ``(tokens, hidden_size)``, as implied by the
        column slicing below.
        """
        # for tractable_init and positive input, this is
        # essentially an elementwise-square
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):
    """Toy attention built around the custom ``silly.attention`` op.

    It applies a fused QKV projection, adds the positions to q and k, calls
    the custom op, and finishes with an output projection.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Output rows [0, h), [h, 2h) and [2h, 3h) produce q, k and v.
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            # Identity blocks: q, k and v are all copies of the input and the
            # output projection is the identity.
            nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
            nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
                                                         config.hidden_size])
            nn.init.eye_(self.qkv_projection.weight.data[2 *
                                                         config.hidden_size:])
            nn.init.eye_(self.output_projection.weight.data)
        else:
            # Xavier-normal init with gain 0.001, seeded from
            # config.random_seed for reproducible weights.
            nn.init.xavier_normal_(self.qkv_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)
            nn.init.xavier_normal_(self.output_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the toy attention.

        For tractable_init this is
        ``output = hidden_states * 3 + positions * 2``:
        ``q = k = hidden_states + positions`` and ``v = hidden_states``, and
        the silly op sums q, k and v.
        """
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)

        # positions is expected to hold one value per token (1-D); unsqueeze
        # broadcasts it across the hidden dimension.
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)

        # The custom op writes its result into this preallocated buffer.
        attn_output = torch.empty_like(q)
        torch.ops.silly.attention(q, k, v, attn_output)

        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: residual add, attention, residual add, MLP."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        For tractable computation:

        - if residual is None, the outputs are:
            - residual = (hidden_states + 1) * 3 + positions * 2
              + hidden_states
              = hidden_states * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        - if residual is not None, with s = hidden_states + residual:
            - residual = (s + 1) * 3 + positions * 2 + s
              = s * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        """
        # NOTE(review): the `+ 1` steps appear to sit where a real Llama layer
        # would normalize; they also keep tractable-mode values positive, which
        # the MLP's x * relu(x) == x ** 2 shortcut relies on. Confirm intent.
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states + 1
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states + 1

        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)

        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1

        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile
class LlamaModel(nn.Module):
    """Toy Llama-style model: a constant embedding plus ``num_layers`` layers.

    It is decorated with vLLM's ``support_torch_compile``; ``test_toy_llama``
    expects exactly one graph from it per compiled run.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 config: LlamaConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        # vllm_config, prefix and kwargs are unused in this body.
        # NOTE(review): presumably they are part of the constructor contract
        # expected for support_torch_compile-decorated modules; keep them.
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])

        # this is the initial value of the hidden states: every embedding row
        # holds the same constant, so the token ids do not affect the output
        self.embedding_tokens.weight.data.fill_(config.init_value)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Only the last layer's hidden states are returned; the final residual
        is discarded.
        """
        # NOTE(review): keep these annotations as they are; the compile
        # decorator may inspect them (e.g. to pick dynamic dimensions).
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


def tractable_computation(input_ids: torch.Tensor,
                          positions: torch.Tensor,
                          config: "LlamaConfig",
                          init_value: Optional[float] = None) -> torch.Tensor:
    """Closed-form output of a ``LlamaModel`` built with ``tractable_init``.

    Every embedding row starts at ``init_value``. Each decoder layer then
    applies the recurrence documented on ``LlamaDecoderLayer.forward``
    (with ``p = positions``)::

        if residual is not None: hidden = hidden + residual
        residual = hidden * 4 + p * 2 + 3
        hidden = (residual + 1) ** 2

    Args:
        input_ids: token ids. Only their count and device are used, because
            every embedding row holds the same constant.
        positions: one position per token, broadcast across the hidden
            dimension.
        config: model config; ``hidden_size``, ``num_layers`` and (by
            default) ``init_value`` are read.
        init_value: starting hidden-state value. Defaults to
            ``config.init_value``, matching how ``LlamaModel`` fills its
            embedding table.

    Returns:
        A tensor of shape ``(len(input_ids), config.hidden_size)`` in the
        default floating-point dtype.
    """
    if init_value is None:
        init_value = config.init_value

    # Build the start state in floating point. input_ids are integer ids, so
    # reusing their dtype would keep the whole recurrence in integer
    # arithmetic whenever init_value is an int.
    hidden_states = torch.full((input_ids.size(0), config.hidden_size),
                               float(init_value),
                               device=input_ids.device,
                               dtype=torch.get_default_dtype())
    pos = positions.unsqueeze(1)

    # One iteration per layer, mirroring LlamaDecoderLayer; with
    # num_layers == 0 the embedding value is returned unchanged, like the model.
    residual: Optional[torch.Tensor] = None
    for _ in range(config.num_layers):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states * 4 + pos * 2 + 3
        hidden_states = (residual + 1)**2

    return hidden_states


@torch.inference_mode
def run_model(llama_config,
              use_compile: bool,
              use_inductor: bool,
              split_attn: bool = False) -> Optional[torch.Tensor]:
    """Build a ``LlamaModel`` under the requested compilation mode and run it.

    Args:
        llama_config: ``LlamaConfig`` for the model. It is also passed to
            ``VllmConfig`` as ``additional_config``.
        use_compile: if True, use ``CompilationLevel.PIECEWISE`` with
            cudagraph capture sizes ``[1, 2]``; otherwise
            ``CompilationLevel.NO_COMPILATION``.
        use_inductor: forwarded to ``CompilationConfig.use_inductor`` when
            compiling.
        split_attn: when compiling, split the graph at ``silly.attention``.

    Returns:
        The CPU output of the final 2-token call. Returns ``None`` when
        ``llama_config.tractable_init`` is set; in that case the output is
        instead asserted to match ``tractable_computation``.
    """
    if use_compile:
        compilation_config = CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            use_inductor=use_inductor,
            cudagraph_capture_sizes=[1, 2],
        )
        if split_attn:
            compilation_config.splitting_ops = ["silly.attention"]
    else:
        compilation_config = CompilationConfig(
            level=CompilationLevel.NO_COMPILATION)

    vllm_config = VllmConfig(compilation_config=compilation_config,
                             additional_config=llama_config)
    with set_current_vllm_config(vllm_config):
        model = LlamaModel(config=llama_config,
                           vllm_config=vllm_config,
                           prefix="").eval().cuda()

    with set_forward_context({}, vllm_config=vllm_config):
        B = 16  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
        positions = torch.arange(B).cuda()

        # Run the max batch, then each cudagraph capture size (2 and 1).
        model(input_ids, positions)
        model(input_ids[:2], positions[:2])
        model(input_ids[:1], positions[:1])

        # Change the token ids in place before the final call. The embedding
        # table is constant, so the expected output is unaffected.
        input_ids[:2].zero_()
        output = model(input_ids[:2], positions[:2]).cpu()

        if llama_config.tractable_init:
            expected_output = tractable_computation(input_ids[:2],
                                                    positions[:2],
                                                    llama_config).cpu()
            assert torch.allclose(output, expected_output)
            return None
        return output


@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
    """Compare the toy model's output with and without piecewise compilation.

    ``llama_config`` (random weights) is run in three modes: without
    compilation, compiled as a single graph, and compiled with the graph
    split at ``silly.attention``. All three outputs must match, and the
    compilation counters are checked around each of those runs.
    ``tractable_config`` is run in the same modes outside the counter checks,
    and ``run_model`` checks each of those outputs against
    ``tractable_computation``.
    """
    llama_config = LlamaConfig(hidden_size=128,
                               mlp_size=256,
                               vocab_size=128,
                               num_layers=12)

    tractable_config = LlamaConfig(hidden_size=128,
                                   mlp_size=256,
                                   vocab_size=128,
                                   num_layers=2,
                                   tractable_init=True)

    # run_model uses cudagraph_capture_sizes=[1, 2] when compiling.
    num_cudagraph_sizes = 2
    num_layers = llama_config.num_layers

    outputs = []

    # No compilation: nothing is traced, compiled or captured.
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ):
        outputs.append(
            run_model(llama_config, use_inductor=False, use_compile=False))
    run_model(tractable_config, use_inductor=False, use_compile=False)

    if use_inductor:
        kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
    else:
        kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

    # Compiled without splitting ops: the model is a single piecewise graph.
    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=1,
            num_piecewise_capturable_graphs_seen=1,
            # num_piecewise_capturable_graphs_seen
            num_backend_compilations=1,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
            num_cudagraph_captured=num_cudagraph_sizes * 1,
            **kwargs,
    ):
        outputs.append(
            run_model(llama_config,
                      use_inductor=use_inductor,
                      use_compile=True))
    run_model(tractable_config, use_inductor=use_inductor, use_compile=True)

    # Split at silly.attention (one call per layer). The graph breaks into
    # 2 * num_layers + 1 pieces, of which 1 + num_layers are capturable
    # (presumably the pieces between the attention ops).
    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=2 * num_layers + 1,
            num_piecewise_capturable_graphs_seen=1 + num_layers,
            # num_piecewise_capturable_graphs_seen
            num_backend_compilations=1 + num_layers,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
            num_cudagraph_captured=num_cudagraph_sizes * (1 + num_layers),
    ):
        outputs.append(
            run_model(llama_config,
                      use_inductor=use_inductor,
                      use_compile=True,
                      split_attn=True))
    run_model(tractable_config,
              use_inductor=use_inductor,
              use_compile=True,
              split_attn=True)

    for i, output in enumerate(outputs[1:], start=1):
        assert torch.allclose(outputs[0], output), (
            f"compiled run {i} differs from the uncompiled run")


@torch.inference_mode
def benchmark(use_large_model: bool = False):
    """Print per-batch-size latencies of three execution modes.

    The columns are "eager mode", full cudagraph and piecewise cudagraph.

    Args:
        use_large_model: benchmark a config similar to Llama 3.1-8B instead
            of the default tiny model, which is meant to expose the overhead
            of piecewise cudagraphs.
    """
    import functools

    from triton.testing import do_bench

    if use_large_model:
        # similar to llama 3.1-8B
        llama_config = LlamaConfig(hidden_size=4096,
                                   mlp_size=14336,
                                   vocab_size=128 * 1024,
                                   num_layers=32)
    else:
        # a tiny model to measure the overhead
        # of piecewise cudagraph
        llama_config = LlamaConfig(hidden_size=40,
                                   mlp_size=80,
                                   vocab_size=128,
                                   num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}
    piecewise_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    for piecewise in [False, True]:
        if piecewise:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                use_cudagraph=True,
                splitting_ops=["silly.attention"],
                cudagraph_capture_sizes=cudagraph_sizes,
            )
        else:
            # NOTE(review): this model is still built with PIECEWISE level
            # (no splitting_ops, use_cudagraph not set), so the "eager mode"
            # column presumably measures it without vLLM cudagraphs; confirm.
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                cudagraph_capture_sizes=cudagraph_sizes,
            )

        vllm_config = VllmConfig(compilation_config=compilation_config)
        with set_current_vllm_config(vllm_config):
            model = LlamaModel(config=llama_config,
                               vllm_config=vllm_config,
                               prefix="").eval().cuda().to(torch.bfloat16)

        B = 256  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
        positions = torch.arange(B).cuda().to(torch.bfloat16)

        graphs = {}

        model(input_ids, positions)
        for b in cudagraph_sizes[::-1]:
            if not piecewise:
                # Manually capture one full CUDA graph per batch size.
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph, pool=pool):
                    output = model(input_ids[:b], positions[:b])
                graphs[b] = (graph, output)
            else:
                output = model(input_ids[:b], positions[:b])
                graphs[b] = (model, output)

        for b in cudagraph_sizes:
            # Slice once per batch size and bind the arguments eagerly with
            # functools.partial, so no closure late-binds the loop variable
            # and the slicing stays out of the timed region.
            ids, pos = input_ids[:b], positions[:b]
            if piecewise:
                piecewise_cudagraph_time[b] = do_bench(
                    functools.partial(graphs[b][0], ids, pos))
            else:
                full_cudagraph_time[b] = do_bench(graphs[b][0].replay)
                eager_time[b] = do_bench(functools.partial(model, ids, pos))

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
    for b in cudagraph_sizes:
        print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
              f"\t{piecewise_cudagraph_time[b]:.3f}")


if __name__ == "__main__":
    benchmark()