# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
8
from vllm.inputs import TokensPrompt
9
10
11
from vllm.v1.metrics.reader import Counter, Vector

# Prefer vLLM's FlexibleArgumentParser; fall back to the stdlib ArgumentParser
# when that import path is unavailable (presumably vLLM versions where the
# helper lives elsewhere -- confirm if this fallback is still needed).
try:
    from vllm.utils.argparse_utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser


# Question asked about every image when --custom-mm-prompts is used.
QUESTION = "What is the content of each image?"
# Public sample images for --custom-mm-prompts; get_custom_mm_prompts pairs
# each one with QUESTION in a single user chat turn.
IMAGE_URLS = [
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build ``num_prompts`` single-turn chat conversations, one image each.

    Each conversation is a one-element list holding a ``"user"`` message whose
    content is an ``image_url`` part (taken from ``IMAGE_URLS``) followed by a
    text part containing ``QUESTION``. When more prompts are requested than
    there are images, the images are reused in order, cycling through the list.
    """
    per_image_content = [
        [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": QUESTION},
        ]
        for image_url in IMAGE_URLS
    ]
    # Tile the base list enough times to cover num_prompts before slicing.
    copies = (
        num_prompts // len(IMAGE_URLS) + 1 if num_prompts > len(IMAGE_URLS) else 1
    )
    chosen = (per_image_content * copies)[:num_prompts]
    return [[{"role": "user", "content": content}] for content in chosen]


def parse_args(argv=None):
    """Parse command-line options for the speculative-decoding example.

    Dataset-selection options are registered by ``add_dataset_parser``; the
    options added here choose the target model, the speculative method and
    its settings, engine settings, and sampling/output settings.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses the
            process command line, matching the previous behaviour.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)
    # Run the built-in acceptance-length regression check after generation.
    parser.add_argument("--test", action="store_true")

    # Speculative decoding method and its settings.
    parser.add_argument(
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
    )
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    # Prompt-lookup bounds; only forwarded when --method is ngram.
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
    parser.add_argument("--prompt-lookup-min", type=int, default=2)

    # Engine configuration.
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--enforce-eager", action="store_true")
    parser.add_argument("--enable-chunked-prefill", action="store_true")
    parser.add_argument("--max-model-len", type=int, default=16384)

    # Sampling and output.
    parser.add_argument("--temp", type=float, default=0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=-1)
    parser.add_argument("--print-output", action="store_true")
    parser.add_argument("--output-len", type=int, default=256)

    # Model locations / prompt source.
    parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--eagle-dir", type=str, default=None)
    parser.add_argument("--draft-model", type=str, default=None)
    parser.add_argument("--custom-mm-prompts", action="store_true")

    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
    parser.add_argument("--disable-padded-drafter-batch", action="store_true")
    parser.add_argument("--max-num-seqs", type=int, default=None)

    return parser.parse_args(argv)


def _build_speculative_config(args):
    """Translate the CLI options into the ``speculative_config`` dict for ``LLM``.

    Raises:
        ValueError: if ``--method draft_model`` is given without a non-empty
            ``--draft-model``, or if the method is not recognised.
    """
    method = args.method
    if method in ("eagle", "eagle3"):
        eagle_dir = args.eagle_dir
        if eagle_dir is None:
            # Default drafter checkpoints; their names suggest they pair with
            # the default Llama-3.1-8B-Instruct target model.
            eagle_dir = (
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
                if method == "eagle"
                else "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
            )
        return {
            "method": method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
            "disable_padded_drafter_batch": args.disable_padded_drafter_batch,
        }
    if method == "ngram":
        return {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }
    if method == "draft_model":
        # Raise instead of assert so the check survives `python -O`.
        if not args.draft_model:
            raise ValueError("--draft-model must be set when --method is draft_model")
        return {
            "method": method,
            "model": args.draft_model,
            "num_speculative_tokens": args.num_spec_tokens,
            "enforce_eager": args.enforce_eager,
            "max_model_len": args.max_model_len,
        }
    if method == "mtp":
        return {
            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    raise ValueError(f"unknown method: {method}")


def _aggregate_spec_decode_metrics(metrics, num_spec_tokens):
    """Sum the speculative-decoding counters reported by ``llm.get_metrics()``.

    Returns:
        ``(num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts)``
        where ``acceptance_counts[i]`` is the summed per-position accepted
        count reported for draft position ``i``.
    """
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    acceptance_counts = [0] * num_spec_tokens
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            # Grow instead of raising IndexError if the engine reports more
            # positions than --num-spec-tokens.
            missing = len(metric.values) - len(acceptance_counts)
            if missing > 0:
                acceptance_counts.extend([0] * missing)
            for pos, count in enumerate(metric.values):
                acceptance_counts[pos] += count
    return num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts


def main(args):
    """Generate with speculative decoding and print acceptance statistics.

    Prompts come either from the dataset selected by the dataset options
    (via ``get_samples``) or, with ``--custom-mm-prompts``, from
    ``get_custom_mm_prompts``.

    Args:
        args: Namespace produced by ``parse_args``. This function also sets
            ``args.endpoint_type`` and ``args.custom_skip_chat_template``
            before calling ``get_samples`` (presumably read there -- confirm).

    Returns:
        Mean acceptance length: ``1 + num_accepted_tokens / num_drafts``, or
        ``1`` when no drafts were recorded.

    Raises:
        ValueError: if ``--custom-mm-prompts`` is used without ``--model-dir``,
            or the speculative configuration is invalid.
    """
    args.endpoint_type = "openai-chat"

    model_dir = args.model_dir
    if model_dir is None:
        if args.custom_mm_prompts:
            raise ValueError(
                "custom_mm_prompts requires mm based models; "
                "default llama3.1-8b-instruct is not mm based; "
                "please specify model_dir to give a mm based model"
            )
        model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    args.custom_skip_chat_template = True

    if not args.custom_mm_prompts:
        prompts = get_samples(args, tokenizer)
        # add_special_tokens is False to avoid adding bos twice
        # when using chat templates
        prompt_ids = [
            tokenizer.encode(prompt.prompt, add_special_tokens=False)
            for prompt in prompts
        ]
    else:
        prompts = get_custom_mm_prompts(args.num_prompts)

    speculative_config = _build_speculative_config(args)

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        enforce_eager=args.enforce_eager,
        gpu_memory_utilization=args.gpu_memory_utilization,
        speculative_config=speculative_config,
        disable_log_stats=False,
        max_model_len=args.max_model_len,
        limit_mm_per_prompt={"image": 5},
        disable_chunked_mm_input=True,
        max_num_seqs=args.max_num_seqs,
    )

    # --top-p / --top-k were previously parsed but never applied.
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.output_len,
    )
    if not args.custom_mm_prompts:
        outputs = llm.generate(
            [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids],
            sampling_params=sampling_params,
        )
    else:
        outputs = llm.chat(prompts, sampling_params=sampling_params)

    # print the generated text
    if args.print_output:
        for prompt, output in zip(prompts, outputs):
            print("-" * 50)
            if not args.custom_mm_prompts:
                print(f"prompt: {prompt.prompt}")
            else:
                print(f"prompt: {prompt}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    metrics = llm.get_metrics()

    total_num_output_tokens = sum(
        len(output.outputs[0].token_ids) for output in outputs
    )
    num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts = (
        _aggregate_spec_decode_metrics(metrics, args.num_spec_tokens)
    )

    print("-" * 50)
    print(f"total_num_output_tokens: {total_num_output_tokens}")
    print(f"num_drafts: {num_drafts}")
    print(f"num_draft_tokens: {num_draft_tokens}")
    print(f"num_accepted_tokens: {num_accepted_tokens}")
    # Every draft step yields at least one token, hence the leading 1.
    acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
    print(f"mean acceptance length: {acceptance_length:.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i, count in enumerate(acceptance_counts):
        acceptance_rate = count / num_drafts if num_drafts > 0 else 0
        print(f"acceptance at token {i}: {acceptance_rate:.2f}")

    return acceptance_length

if __name__ == "__main__":
    args = parse_args()
    acceptance_length = main(args)

    if args.test:
        # takes ~30s to run on 1xH100
        # The expected acceptance lengths below only hold for this exact
        # reference configuration, so pin every setting that affects them.
        assert args.method in ["eagle", "eagle3"]
        assert args.tp == 1
        assert args.num_spec_tokens == 3
        assert args.dataset_name == "hf"
        assert args.dataset_path == "philschmid/mt-bench"
        assert args.num_prompts == 80
        assert args.temp == 0
        assert args.top_p == 1.0
        assert args.top_k == -1
        assert args.enable_chunked_prefill

        # check acceptance length is within 2% of expected value
        rtol = 0.02
        expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
        lower_bound = (1 - rtol) * expected_acceptance_length
        upper_bound = (1 + rtol) * expected_acceptance_length

        assert lower_bound <= acceptance_length <= upper_bound, (
            f"acceptance_length {acceptance_length} is not "
            f"within {rtol * 100}% of {expected_acceptance_length}"
        )

        print(
            f"Test passed! Expected AL: "
            f"{expected_acceptance_length}, got {acceptance_length}"
        )