# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.v1.metrics.reader import Counter, Vector

try:
    from vllm.utils.argparse_utils import FlexibleArgumentParser
except ImportError:
    # Fall back to the standard-library parser when the vLLM helper module
    # cannot be imported, so the script still exposes the same name.
    from argparse import ArgumentParser as FlexibleArgumentParser


# Text question paired with every image in the custom multimodal prompts.
QUESTION = "What is the content of each image?"

# Image URLs used to build the custom multimodal prompts; one prompt is
# created per URL (see get_custom_mm_prompts).
IMAGE_URLS = [
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build ``num_prompts`` single-turn chat conversations about the images.

    Each conversation is a one-element list holding a ``user`` message whose
    content is an image URL entry followed by the ``QUESTION`` text entry.
    When more prompts are requested than there are entries in ``IMAGE_URLS``,
    the per-image contents are repeated in order before being truncated to
    ``num_prompts``.
    """
    contents = [
        [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": QUESTION},
        ]
        for image_url in IMAGE_URLS
    ]

    pool_size = len(IMAGE_URLS)
    if num_prompts > pool_size:
        # One extra repetition guarantees the slice below has enough items.
        repeats = num_prompts // pool_size + 1
        contents = contents * repeats

    conversations = []
    for content in contents[:num_prompts]:
        conversations.append([{"role": "user", "content": content}])
    return conversations


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    The dataset options are registered by ``add_dataset_parser``; the
    remaining options configure the speculative-decoding method, the engine,
    sampling, and output/test behaviour.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)

    # Test mode and speculative-decoding method selection.
    parser.add_argument("--test", action="store_true")
    parser.add_argument(
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
    )
    parser.add_argument("--backend", type=str, default="openai")

    # Speculative-decoding and engine options.
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
    parser.add_argument("--prompt-lookup-min", type=int, default=2)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--enforce-eager", action="store_true")
    parser.add_argument("--enable-chunked-prefill", action="store_true")
    parser.add_argument("--max-model-len", type=int, default=16384)

    # Sampling and output options.
    parser.add_argument("--temp", type=float, default=0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=-1)
    parser.add_argument("--print-output", action="store_true")
    parser.add_argument("--output-len", type=int, default=256)

    # Model locations and multimodal prompt selection.
    parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--eagle-dir", type=str, default=None)
    parser.add_argument("--draft-model", type=str, default=None)
    parser.add_argument("--custom-mm-prompts", action="store_true")

    # Remaining engine options forwarded to LLM(...) in main().
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
    parser.add_argument("--disable-padded-drafter-batch", action="store_true")
    parser.add_argument("--max-num-seqs", type=int, default=None)
    parser.add_argument("--allowed-local-media-path", type=str, default="")
    return parser.parse_args()


def _build_speculative_config(args):
    """Translate the parsed CLI options into the ``speculative_config`` dict.

    Raises:
        ValueError: if ``args.method`` is not a known method, or if the
            ``draft_model`` method is selected without ``--draft-model``.
    """
    method = args.method

    if method in ("eagle", "eagle3"):
        eagle_dir = args.eagle_dir
        if eagle_dir is None:
            # Default draft checkpoints matching the default target model.
            eagle_dir = (
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
                if method == "eagle"
                else "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
            )
        return {
            "method": method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
            "disable_padded_drafter_batch": args.disable_padded_drafter_batch,
        }

    if method == "ngram":
        return {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }

    if method == "draft_model":
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not args.draft_model:
            raise ValueError(
                "--draft-model must be specified when --method is draft_model"
            )
        return {
            "method": method,
            "model": args.draft_model,
            "num_speculative_tokens": args.num_spec_tokens,
            "enforce_eager": args.enforce_eager,
            "max_model_len": args.max_model_len,
        }

    if method == "mtp":
        return {
            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }

    raise ValueError(f"unknown method: {args.method}")


def _aggregate_spec_decode_metrics(metrics, num_spec_tokens):
    """Sum the speculative-decoding counters found in ``metrics``.

    Returns:
        A tuple ``(num_drafts, num_draft_tokens, num_accepted_tokens,
        acceptance_counts)`` where ``acceptance_counts`` has one entry per
        speculative token position.
    """
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    acceptance_counts = [0] * num_spec_tokens

    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos, count in enumerate(metric.values):
                acceptance_counts[pos] += count

    return num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts


def main(args):
    """Run generation with speculative decoding and report acceptance stats.

    Args:
        args: Namespace from ``parse_args()``. ``args.enable_multimodal_chat``
            must also be set by the caller; it is read when dataset prompts
            are prepared.

    Returns:
        The mean acceptance length, computed as
        ``1 + num_accepted_tokens / num_drafts`` (or ``1`` when no drafts
        were recorded).

    Raises:
        ValueError: if ``--custom-mm-prompts`` is used without ``--model-dir``,
            if the method is unknown, or if the ``draft_model`` method is
            selected without ``--draft-model``.
    """
    model_dir = args.model_dir
    if model_dir is None:
        if args.custom_mm_prompts:
            raise ValueError(
                "custom_mm_prompts requires mm based models; "
                "default llama3.1-8b-instruct is not mm based, "
                "please specify model_dir to give a mm based model"
            )
        model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    if args.custom_mm_prompts:
        prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
    else:
        prompts = get_samples(args, tokenizer)
        if args.enable_multimodal_chat:
            llm_prompts = [p.prompt for p in prompts]
        else:
            # add_special_tokens is False to avoid adding bos twice
            # when using chat templates
            llm_prompts = [
                {
                    "prompt_token_ids": tokenizer.encode(
                        prompt.prompt, add_special_tokens=False
                    ),
                    "multi_modal_data": prompt.multi_modal_data,
                }
                for prompt in prompts
            ]

    speculative_config = _build_speculative_config(args)

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        enforce_eager=args.enforce_eager,
        gpu_memory_utilization=args.gpu_memory_utilization,
        speculative_config=speculative_config,
        # NOTE(review): stats are left enabled, presumably so that
        # llm.get_metrics() below has data to report - confirm.
        disable_log_stats=False,
        max_model_len=args.max_model_len,
        limit_mm_per_prompt={"image": 5},
        disable_chunked_mm_input=True,
        max_num_seqs=args.max_num_seqs,
        allowed_local_media_path=args.allowed_local_media_path,
    )

    # Forward --top-p/--top-k as well; previously they were parsed but ignored.
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.output_len,
    )
    if args.backend == "openai-chat":
        outputs = llm.chat(llm_prompts, sampling_params=sampling_params)
    else:
        outputs = llm.generate(
            llm_prompts,
            sampling_params=sampling_params,
        )

    # print the generated text
    if args.print_output:
        for i, output in enumerate(outputs):
            print("-" * 50)
            if not args.custom_mm_prompts:
                print(f"prompt: {prompts[i].prompt}")
            else:
                print(f"prompt: {prompts[i]}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    metrics = llm.get_metrics()

    total_num_output_tokens = sum(
        len(output.outputs[0].token_ids) for output in outputs
    )
    (
        num_drafts,
        num_draft_tokens,
        num_accepted_tokens,
        acceptance_counts,
    ) = _aggregate_spec_decode_metrics(metrics, args.num_spec_tokens)

    print("-" * 50)
    print(f"total_num_output_tokens: {total_num_output_tokens}")
    print(f"num_drafts: {num_drafts}")
    print(f"num_draft_tokens: {num_draft_tokens}")
    print(f"num_accepted_tokens: {num_accepted_tokens}")
    if num_drafts > 0:
        acceptance_length = 1 + (num_accepted_tokens / num_drafts)
    else:
        acceptance_length = 1
    print(f"mean acceptance length: {acceptance_length:.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i, accepted in enumerate(acceptance_counts):
        acceptance_rate = accepted / num_drafts if num_drafts > 0 else 0
        print(f"acceptance at token {i}: {acceptance_rate:.2f}")

    return acceptance_length


if __name__ == "__main__":
    args = parse_args()
    # main() reads this flag when preparing dataset prompts.
    args.enable_multimodal_chat = args.backend == "openai-chat"

    if args.test:
        # Validate the reference configuration before the expensive run so
        # that a misconfigured test fails immediately.
        # takes ~30s to run on 1xH100
        assert args.method in ["eagle", "eagle3"]
        assert args.tp == 1
        assert args.num_spec_tokens == 3
        assert args.dataset_name == "hf"
        assert args.dataset_path == "philschmid/mt-bench"
        assert args.num_prompts == 80
        assert args.temp == 0
        assert args.top_p == 1.0
        assert args.top_k == -1
        assert args.enable_chunked_prefill

    acceptance_length = main(args)

    if args.test:
        # check acceptance length is within 2% of expected value
        rtol = 0.02
        expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
        lower_bound = (1 - rtol) * expected_acceptance_length
        upper_bound = (1 + rtol) * expected_acceptance_length

        assert lower_bound <= acceptance_length <= upper_bound, (
            f"acceptance_length {acceptance_length} is not "
            f"within {rtol * 100}% of {expected_acceptance_length}"
        )

        print(
            f"Test passed! Expected AL: "
            f"{expected_acceptance_length}, got {acceptance_length}"
        )