"docs/architecture/architecture.md" did not exist on "6d46288c0cad1f07188afde7960c18928b27011c"
benchmark_ngram_proposer.py 6.57 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
4
5
import time
from unittest import mock
6
7

import numpy as np
8
from benchmark_utils import TimeCollector
9
10
from tabulate import tabulate

11
12
13
14
15
16
17
18
19
20
21
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
from vllm.platforms import current_platform
22
from vllm.utils.argparse_utils import FlexibleArgumentParser
23
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
24
25
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
26
27


28
def benchmark_propose(args):
    """Time per-request ``NgramProposer.propose`` calls for each max n-gram.

    For every value in ``args.max_ngram`` a fresh proposer is built, warmed up
    with a single call, and then timed over ``args.num_iteration`` rounds.
    Each round draws a new random token matrix (outside the timed region) and
    calls ``propose`` once per request inside the timed region. A summary
    table with one row per max n-gram value is printed at the end.

    Args:
        args: Parsed command-line namespace. Reads ``max_ngram`` (iterable of
            ints), ``min_ngram``, ``num_token``, ``num_spec_token``,
            ``num_req`` and ``num_iteration``.
    """
    rows = []
    for max_ngram in args.max_ngram:
        # One collector per configuration so rows do not share samples.
        collector = TimeCollector(TimeCollector.US)

        model_config = ModelConfig(
            model="facebook/opt-125m",
            # Leave room for the drafted tokens on top of the context.
            max_model_len=args.num_token + args.num_spec_token,
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            dtype="auto",
            seed=0,
            trust_remote_code=False,
        )
        speculative_config = SpeculativeConfig(
            prompt_lookup_min=args.min_ngram,
            prompt_lookup_max=max_ngram,
            num_speculative_tokens=args.num_spec_token,
            method="ngram",
        )
        proposer = NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=speculative_config,
            )
        )

        # Warm up
        proposer.propose(np.random.randint(0, 20, (args.num_token,)))

        # Collect garbage up front, before the timed rounds begin.
        gc.collect()
        for _ in range(args.num_iteration):
            # Token generation stays outside the timed region.
            tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
            with collector:
                for req_idx in range(args.num_req):
                    proposer.propose(tokens[req_idx, :])

        rows.append(
            [args.num_req, args.num_token, args.min_ngram, max_ngram]
            + collector.dump_avg_max()
        )

    headers = [
        "# Request",
        "# Token",
        "Min Ngram",
        "Max Ngram",
        "Avg (us)",
        "Max (us)",
    ]
    print(tabulate(rows, headers=headers, tablefmt="grid", floatfmt=".3f"))


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
def benchmark_batched_propose(args):
    """Time batched n-gram drafting through a ``GPUModelRunner``'s drafter.

    Builds a ``GPUModelRunner`` configured for n-gram speculative decoding
    (with ``get_pp_group`` patched to a mock whose ``world_size`` is 1),
    installs a synthetic ``InputBatch`` filled with random token ids, and
    calls ``runner.drafter.propose`` ``args.num_iteration`` times, printing
    the elapsed time of each call.

    Args:
        args: Parsed command-line namespace. Reads ``num_req``, ``num_token``
            and ``num_iteration``. The n-gram settings used here are the
            local constants below, not the ``--min-ngram`` / ``--max-ngram`` /
            ``--num-spec-token`` options.
    """
    NUM_SPECULATIVE_TOKENS_NGRAM = 10
    PROMPT_LOOKUP_MIN = 5
    PROMPT_LOOKUP_MAX = 15
    MAX_MODEL_LEN = int(1e7)
    DEVICE = current_platform.device_type

    model_config = ModelConfig(model="facebook/opt-125m", runner="generate")

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        method="ngram",
        num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
        prompt_lookup_max=PROMPT_LOOKUP_MAX,
        prompt_lookup_min=PROMPT_LOOKUP_MIN,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=DEVICE),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(
            max_model_len=model_config.max_model_len,
            is_encoder_decoder=model_config.is_encoder_decoder,
        ),
    )

    # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 1
    with mock.patch(
        "vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
    ):
        runner = GPUModelRunner(vllm_config, DEVICE)

        # hack max model len: override the limit on both the runner and its
        # drafter after construction.
        runner.max_model_len = MAX_MODEL_LEN
        runner.drafter.max_model_len = MAX_MODEL_LEN

        dummy_input_batch = InputBatch(
            max_num_reqs=args.num_req,
            max_model_len=MAX_MODEL_LEN,
            max_num_batched_tokens=args.num_req * args.num_token,
            device=DEVICE,
            pin_memory=False,
            vocab_size=256000,
            block_sizes=[16],
        )
        # Populate the batch fields that are passed to the drafter below.
        dummy_input_batch._req_ids = [
            str(req_idx) for req_idx in range(args.num_req)
        ]
        dummy_input_batch.spec_decode_unsupported_reqs = ()
        dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
        dummy_input_batch.token_ids_cpu = np.random.randint(
            0, 20, (args.num_req, args.num_token)
        )

        runner.input_batch = dummy_input_batch

        # One independent single-token list per request (no shared inner list).
        sampled_token_ids = [[0] for _ in range(args.num_req)]

        print("Starting benchmark")
        # The first iteration doubles as warmup; disregard its reported time.
        for _ in range(args.num_iteration):
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            runner.drafter.propose(
                sampled_token_ids,
                dummy_input_batch.req_ids,
                dummy_input_batch.num_tokens_no_spec,
                dummy_input_batch.token_ids_cpu,
                dummy_input_batch.spec_decode_unsupported_reqs,
            )
            end = time.perf_counter()
            print(f"Iteration time (s): {end - start}")


163
164
165
166
def invoke_main() -> None:
    """Parse command-line options and run the selected n-gram benchmark.

    Without ``--batched`` this runs ``benchmark_propose``; with ``--batched``
    it runs ``benchmark_batched_propose``.
    """
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of N-gram speculative decode drafting"
    )
    parser.add_argument(
        "--batched", action="store_true", help="consider time to prepare batch"
    )
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=100,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--num-req", type=int, default=128, help="Number of requests in the batch"
    )
    parser.add_argument(
        "--num-token", type=int, default=1500, help="Number of tokens for each request"
    )
    parser.add_argument(
        "--min-ngram",
        type=int,
        default=3,
        help="Minimum n-gram to match",
    )
    # Accepts several values; the non-batched benchmark runs once per value.
    parser.add_argument(
        "--max-ngram",
        type=int,
        nargs="*",
        default=[5, 7, 10, 15, 20],
        help="Maximum n-gram to match",
    )
    parser.add_argument(
        "--num-spec-token",
        type=int,
        default=3,
        help="Number of speculative tokens to generate",
    )
    args = parser.parse_args()

    if args.batched:
        benchmark_batched_propose(args)
    else:
        benchmark_propose(args)
207
208


209
210
211
212
213
"""
# Example command lines:
# time python3 benchmarks/benchmark_ngram_proposer.py
# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
"""  # noqa: E501
214
215
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    invoke_main()  # pragma: no cover