spec_decode.py 11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
8
from vllm.utils.argparse_utils import FlexibleArgumentParser
9
10
from vllm.v1.metrics.reader import Counter, Vector

11
12
# Text question attached to every image in the custom multimodal prompts.
QUESTION = "What is the content of each image?"

# Sample images used to build the custom multimodal prompts (one image per
# prompt); the list is cycled when more prompts than images are requested.
IMAGE_URLS = [
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build ``num_prompts`` single-turn chat conversations about images.

    Each conversation is a one-element list holding a ``user`` message whose
    content pairs one entry of ``IMAGE_URLS`` with the ``QUESTION`` text.
    When more prompts are requested than there are images, the image list is
    repeated in order until enough prompts exist.

    Args:
        num_prompts: Number of conversations to return.

    Returns:
        A list of conversations, each of the form
        ``[{"role": "user", "content": [...]}]``.
    """
    contents = [
        [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": QUESTION},
        ]
        for image_url in IMAGE_URLS
    ]

    available = len(IMAGE_URLS)
    if num_prompts > available:
        # One repetition more than the integer quotient always covers the
        # request; the slice below trims the surplus.
        contents *= num_prompts // available + 1

    conversations = []
    for content in contents[:num_prompts]:
        conversations.append([{"role": "user", "content": content}])
    return conversations


43
44
45
def parse_args():
    """Parse the command-line options of the speculative decoding example.

    Besides the options declared here, ``add_dataset_parser`` registers its
    own options on the same parser (presumably the dataset selection flags
    such as ``--dataset-name`` and ``--num-prompts`` that the rest of the
    script reads -- confirm against ``vllm.benchmarks.datasets``).

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)

    # Run mode and speculative decoding method.
    parser.add_argument("--test", action="store_true")
    parser.add_argument(
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
    )
    parser.add_argument("--backend", type=str, default="openai")
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
    parser.add_argument("--prompt-lookup-min", type=int, default=2)

    # Engine options.
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--enforce-eager", action="store_true")
    parser.add_argument("--enable-chunked-prefill", action="store_true")
    parser.add_argument(
        "--enable-multi-layers-mtp",
        action="store_true",
        help="Enable multi-layer MTP (only effective when --method=mtp).",
    )
    parser.add_argument("--max-model-len", type=int, default=16384)

    # Sampling and output options.
    parser.add_argument("--temp", type=float, default=0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=-1)
    parser.add_argument("--print-output", action="store_true")
    parser.add_argument("--output-len", type=int, default=256)

    # Model, draft model and tokenizer locations.
    parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--eagle-dir", type=str, default=None)
    parser.add_argument("--draft-model", type=str, default=None)
    parser.add_argument("--tokenizer-dir", type=str, default=None)

    # Multimodal prompts and remaining engine knobs.
    parser.add_argument("--custom-mm-prompts", action="store_true")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
    parser.add_argument("--disable-padded-drafter-batch", action="store_true")
    parser.add_argument("--max-num-seqs", type=int, default=None)
    parser.add_argument("--parallel-drafting", action="store_true")
    parser.add_argument("--allowed-local-media-path", type=str, default="")
    parser.add_argument("--trust-remote-code", action="store_true")

    return parser.parse_args()


85
def _build_speculative_config(args):
    """Translate the parsed CLI arguments into a ``speculative_config`` dict.

    Raises:
        ValueError: If ``--method=draft_model`` is used without
            ``--draft-model``, or if the method is not one handled here.
    """
    method = args.method

    if method in ("eagle", "eagle3"):
        eagle_dir = args.eagle_dir
        if eagle_dir is None:
            # Fall back to the draft heads published for Llama-3.1-8B-Instruct,
            # which is also the default target model of this script.
            if method == "eagle":
                eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
            else:
                eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
        return {
            "method": method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
            "disable_padded_drafter_batch": args.disable_padded_drafter_batch,
            "parallel_drafting": args.parallel_drafting,
        }

    if method == "ngram":
        return {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }

    if method == "draft_model":
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if not args.draft_model:
            raise ValueError("--draft-model is required when --method=draft_model")
        return {
            "method": method,
            "model": args.draft_model,
            "num_speculative_tokens": args.num_spec_tokens,
            "enforce_eager": args.enforce_eager,
            "max_model_len": args.max_model_len,
            "parallel_drafting": args.parallel_drafting,
        }

    if method == "mtp":
        speculative_config = {
            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }
        if args.enable_multi_layers_mtp:
            speculative_config["enable_multi_layers_mtp"] = True
        return speculative_config

    raise ValueError(f"unknown method: {args.method}")


def _collect_spec_decode_stats(metrics, num_spec_tokens):
    """Sum the speculative decoding counters found in ``metrics``.

    Returns:
        A tuple ``(num_drafts, num_draft_tokens, num_accepted_tokens,
        acceptance_counts)`` where ``acceptance_counts`` has one slot per
        speculative token position.
    """
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    acceptance_counts = [0] * num_spec_tokens

    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos, accepted in enumerate(metric.values):
                acceptance_counts[pos] += accepted

    return num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts


def main(args):
    """Run offline generation with speculative decoding and print its stats.

    Builds the prompts (custom image prompts or dataset samples), creates an
    ``LLM`` configured for ``args.method``, generates completions, and prints
    the draft/acceptance counters read from ``llm.get_metrics()``.

    Args:
        args: Parsed arguments from ``parse_args()``; ``enable_multimodal_chat``
            must also be set on it (the script entry point does this).

    Returns:
        The mean acceptance length, ``1 + accepted_tokens / drafts``, or ``1``
        when no draft was recorded.

    Raises:
        ValueError: If custom multimodal prompts are requested without
            ``--model-dir``, or the speculative method settings are invalid.
    """
    model_dir = args.model_dir
    if model_dir is None:
        if args.custom_mm_prompts:
            raise ValueError(
                "custom_mm_prompts requires mm based models; "
                "default llama3.1-8b-instruct is not mm based, "
                "please specify model_dir to give a mm based model"
            )
        model_dir = "meta-llama/Llama-3.1-8B-Instruct"

    tokenizer_dir = args.tokenizer_dir
    if tokenizer_dir is None:
        tokenizer_dir = model_dir

    # Honour --trust-remote-code so tokenizers shipping custom code can load.
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=args.trust_remote_code
    )

    if args.custom_mm_prompts:
        prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
    else:
        prompts = get_samples(args, tokenizer)
        if args.enable_multimodal_chat:
            llm_prompts = [sample.prompt for sample in prompts]
        else:
            # add_special_tokens is False to avoid adding bos twice
            # when using chat templates
            llm_prompts = [
                {
                    "prompt_token_ids": tokenizer.encode(
                        sample.prompt, add_special_tokens=False
                    ),
                    "multi_modal_data": sample.multi_modal_data,
                }
                for sample in prompts
            ]

    speculative_config = _build_speculative_config(args)

    llm = LLM(
        model=model_dir,
        # NOTE(review): kept hard-coded for backward compatibility; consider
        # wiring --trust-remote-code through here as well.
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        enforce_eager=args.enforce_eager,
        gpu_memory_utilization=args.gpu_memory_utilization,
        speculative_config=speculative_config,
        disable_log_stats=False,
        max_model_len=args.max_model_len,
        limit_mm_per_prompt={"image": 5},
        disable_chunked_mm_input=True,
        max_num_seqs=args.max_num_seqs,
        allowed_local_media_path=args.allowed_local_media_path,
    )

    # Forward --top-p/--top-k too; they used to be parsed but silently ignored.
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.output_len,
    )
    if args.backend == "openai-chat":
        outputs = llm.chat(llm_prompts, sampling_params=sampling_params)
    else:
        outputs = llm.generate(
            llm_prompts,
            sampling_params=sampling_params,
        )

    # print the generated text
    if args.print_output:
        for i, output in enumerate(outputs):
            print("-" * 50)
            if not args.custom_mm_prompts:
                print(f"prompt: {prompts[i].prompt}")
            else:
                print(f"prompt: {prompts[i]}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    metrics = llm.get_metrics()

    total_num_output_tokens = sum(
        len(output.outputs[0].token_ids) for output in outputs
    )
    (
        num_drafts,
        num_draft_tokens,
        num_accepted_tokens,
        acceptance_counts,
    ) = _collect_spec_decode_stats(metrics, args.num_spec_tokens)

    print("-" * 50)
    print(f"total_num_output_tokens: {total_num_output_tokens}")
    print(f"num_drafts: {num_drafts}")
    print(f"num_draft_tokens: {num_draft_tokens}")
    print(f"num_accepted_tokens: {num_accepted_tokens}")
    acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
    print(f"mean acceptance length: {acceptance_length:.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i, accepted in enumerate(acceptance_counts):
        acceptance_rate = accepted / num_drafts if num_drafts > 0 else 0
        print(f"acceptance at token {i}: {acceptance_rate:.2f}")

    return acceptance_length

236
237

if __name__ == "__main__":
    args = parse_args()
    # main() reads this flag to choose between chat-style prompts and
    # pre-tokenized prompts.
    args.enable_multimodal_chat = args.backend == "openai-chat"

    acceptance_length = main(args)

    if args.test:
        # takes ~30s to run on 1xH100
        # The expected acceptance lengths below only apply to this exact
        # configuration, so refuse to compare under any other settings.
        assert args.method in ["eagle", "eagle3"]
        assert args.tp == 1
        assert args.num_spec_tokens == 3
        assert args.dataset_name == "hf"
        assert args.dataset_path == "philschmid/mt-bench"
        assert args.num_prompts == 80
        assert args.temp == 0
        assert args.top_p == 1.0
        assert args.top_k == -1
        assert args.enable_chunked_prefill

        # check acceptance length is within 2% of expected value
        rtol = 0.02
        expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
        lower_bound = (1 - rtol) * expected_acceptance_length
        upper_bound = (1 + rtol) * expected_acceptance_length

        assert lower_bound <= acceptance_length <= upper_bound, (
            f"acceptance_length {acceptance_length} is not "
            f"within {rtol * 100}% of {expected_acceptance_length}"
        )

        print(
            f"Test passed! Expected AL: "
            f"{expected_acceptance_length}, got {acceptance_length}"
        )