# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
8
from vllm.utils.argparse_utils import FlexibleArgumentParser
9
10
from vllm.v1.metrics.reader import Counter, Vector

11
12
# Text question attached to every image in the custom multimodal prompts.
QUESTION = "What is the content of each image?"

# Sample image URLs; get_custom_mm_prompts() builds one prompt per entry and
# cycles through the list when more prompts are requested than there are URLs.
IMAGE_URLS = [
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build ``num_prompts`` single-turn chat conversations about sample images.

    Each conversation holds one user message whose content is an image URL
    from ``IMAGE_URLS`` followed by the text ``QUESTION``. When more prompts
    are requested than there are URLs, the URL list is cycled from the start.

    Args:
        num_prompts: Number of conversations to return.

    Returns:
        A list of conversations; each conversation is a one-element list
        containing a ``{"role": "user", "content": ...}`` message.
    """
    contents = [
        [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": QUESTION},
        ]
        for image_url in IMAGE_URLS
    ]

    pool_size = len(IMAGE_URLS)
    if num_prompts > pool_size:
        # Repeat the pool enough times to cover the request; the slice below
        # trims the surplus.
        repeats = num_prompts // pool_size + 1
        contents = contents * repeats

    conversations = []
    for content in contents[:num_prompts]:
        conversations.append([{"role": "user", "content": content}])
    return conversations


43
44
45
def parse_args():
    """Build the argument parser for this example and parse the command line.

    Dataset-related options are registered by ``add_dataset_parser``; the
    options added here select the speculative decoding method and configure
    the engine, sampling, and output.

    Returns:
        The parsed argument namespace.
    """
    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)

    # Run mode and speculative decoding method.
    parser.add_argument(
        "--test",
        action="store_true",
        help="After the run, check the mean acceptance length against the "
        "expected value.",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
        help="Speculative decoding method.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="openai",
        help="'openai-chat' sends prompts through llm.chat(); any other "
        "value uses llm.generate().",
    )

    # Speculation and engine settings.
    parser.add_argument(
        "--num-spec-tokens",
        type=int,
        default=2,
        help="Value of num_speculative_tokens in the speculative config.",
    )
    parser.add_argument(
        "--prompt-lookup-max",
        type=int,
        default=5,
        help="prompt_lookup_max for the ngram method.",
    )
    parser.add_argument(
        "--prompt-lookup-min",
        type=int,
        default=2,
        help="prompt_lookup_min for the ngram method.",
    )
    parser.add_argument(
        "--tp", type=int, default=1, help="Tensor parallel size."
    )
    parser.add_argument("--enforce-eager", action="store_true")
    parser.add_argument("--enable-chunked-prefill", action="store_true")
    parser.add_argument("--max-model-len", type=int, default=16384)

    # Sampling and output.
    parser.add_argument(
        "--temp", type=float, default=0, help="Sampling temperature."
    )
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=-1)
    parser.add_argument(
        "--print-output",
        action="store_true",
        help="Print each prompt and its generated text.",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=256,
        help="Maximum number of tokens to generate per prompt.",
    )

    # Model locations.
    parser.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help="Target model; defaults to meta-llama/Llama-3.1-8B-Instruct.",
    )
    parser.add_argument(
        "--eagle-dir",
        type=str,
        default=None,
        help="Draft model for the eagle/eagle3 methods.",
    )
    parser.add_argument(
        "--draft-model",
        type=str,
        default=None,
        help="Draft model; required for the draft_model method.",
    )
    parser.add_argument(
        "--custom-mm-prompts",
        action="store_true",
        help="Use the built-in image prompts instead of a dataset.",
    )

    # Remaining engine options, forwarded to LLM(...) or the speculative config.
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
    parser.add_argument("--disable-padded-drafter-batch", action="store_true")
    parser.add_argument("--max-num-seqs", type=int, default=None)
    parser.add_argument("--parallel-drafting", action="store_true")
    parser.add_argument("--allowed-local-media-path", type=str, default="")

    return parser.parse_args()


78
def _build_prompts(args, tokenizer):
    """Return ``(prompts, llm_prompts)``: prompts for printing, inputs for the engine."""
    if args.custom_mm_prompts:
        prompts = get_custom_mm_prompts(args.num_prompts)
        return prompts, prompts

    prompts = get_samples(args, tokenizer)
    if args.enable_multimodal_chat:
        return prompts, [sample.prompt for sample in prompts]

    # add_special_tokens is False to avoid adding bos twice
    # when using chat templates
    llm_prompts = [
        {
            "prompt_token_ids": tokenizer.encode(
                sample.prompt, add_special_tokens=False
            ),
            "multi_modal_data": sample.multi_modal_data,
        }
        for sample in prompts
    ]
    return prompts, llm_prompts


def _build_speculative_config(args):
    """Translate the CLI arguments into the ``speculative_config`` dict for ``LLM``."""
    method = args.method

    if method in ("eagle", "eagle3"):
        eagle_dir = args.eagle_dir
        if eagle_dir is None:
            # Default draft heads that pair with the default Llama-3.1-8B target.
            eagle_dir = (
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
                if method == "eagle"
                else "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
            )
        return {
            "method": method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
            "disable_padded_drafter_batch": args.disable_padded_drafter_batch,
            "parallel_drafting": args.parallel_drafting,
        }

    if method == "ngram":
        return {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }

    if method == "draft_model":
        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if not args.draft_model:
            raise ValueError(
                "--draft-model must be specified when --method is draft_model"
            )
        return {
            "method": method,
            "model": args.draft_model,
            "num_speculative_tokens": args.num_spec_tokens,
            "enforce_eager": args.enforce_eager,
            "max_model_len": args.max_model_len,
            "parallel_drafting": args.parallel_drafting,
        }

    if method == "mtp":
        return {
            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }

    raise ValueError(f"unknown method: {args.method}")


def _collect_spec_decode_stats(metrics, num_spec_tokens):
    """Sum the spec-decode metrics into (drafts, draft tokens, accepted tokens, per-position counts)."""
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    acceptance_counts = [0] * num_spec_tokens

    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            # Grow the tally rather than fail with IndexError if the reported
            # vector has more positions than --num-spec-tokens.
            missing = len(metric.values) - len(acceptance_counts)
            if missing > 0:
                acceptance_counts.extend([0] * missing)
            for pos, count in enumerate(metric.values):
                acceptance_counts[pos] += count

    return num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts


def main(args):
    """Run offline generation with speculative decoding and report acceptance stats.

    Args:
        args: Namespace produced by ``parse_args()``. It must also carry the
            ``enable_multimodal_chat`` attribute, which the ``__main__``
            guard sets from ``--backend``.

    Returns:
        The mean acceptance length, ``1 + num_accepted_tokens / num_drafts``,
        or ``1`` when no drafts were recorded.

    Raises:
        ValueError: If ``--custom-mm-prompts`` is used without ``--model-dir``,
            if ``--method draft_model`` is used without ``--draft-model``, or
            if the method is unknown.
    """
    model_dir = args.model_dir
    if model_dir is None:
        if args.custom_mm_prompts:
            raise ValueError(
                "custom_mm_prompts requires mm based models; "
                "default llama3.1-8b-instruct is not mm based, "
                "please specify model_dir to give a mm based model"
            )
        model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    prompts, llm_prompts = _build_prompts(args, tokenizer)
    speculative_config = _build_speculative_config(args)

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        enforce_eager=args.enforce_eager,
        gpu_memory_utilization=args.gpu_memory_utilization,
        speculative_config=speculative_config,
        # Stats must stay enabled: the acceptance numbers below are read
        # from llm.get_metrics().
        disable_log_stats=False,
        max_model_len=args.max_model_len,
        limit_mm_per_prompt={"image": 5},
        disable_chunked_mm_input=True,
        max_num_seqs=args.max_num_seqs,
        allowed_local_media_path=args.allowed_local_media_path,
    )

    # Forward --top-p / --top-k as well, so those flags are not silently ignored.
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.output_len,
    )
    if args.backend == "openai-chat":
        outputs = llm.chat(llm_prompts, sampling_params=sampling_params)
    else:
        outputs = llm.generate(
            llm_prompts,
            sampling_params=sampling_params,
        )

    # print the generated text
    if args.print_output:
        for i, output in enumerate(outputs):
            print("-" * 50)
            if not args.custom_mm_prompts:
                print(f"prompt: {prompts[i].prompt}")
            else:
                print(f"prompt: {prompts[i]}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    metrics = llm.get_metrics()

    total_num_output_tokens = sum(
        len(output.outputs[0].token_ids) for output in outputs
    )
    (
        num_drafts,
        num_draft_tokens,
        num_accepted_tokens,
        acceptance_counts,
    ) = _collect_spec_decode_stats(metrics, args.num_spec_tokens)

    print("-" * 50)
    print(f"total_num_output_tokens: {total_num_output_tokens}")
    print(f"num_drafts: {num_drafts}")
    print(f"num_draft_tokens: {num_draft_tokens}")
    print(f"num_accepted_tokens: {num_accepted_tokens}")
    acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
    print(f"mean acceptance length: {acceptance_length:.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i, accepted in enumerate(acceptance_counts):
        acceptance_rate = accepted / num_drafts if num_drafts > 0 else 0
        print(f"acceptance at token {i}: {acceptance_rate:.2f}")

    return acceptance_length

223
224

if __name__ == "__main__":
    args = parse_args()
    # main() reads this attribute to decide whether prompts are passed as
    # chat messages or as pre-tokenized inputs.
    args.enable_multimodal_chat = args.backend == "openai-chat"

    acceptance_length = main(args)

    if args.test:
        # takes ~30s to run on 1xH100
        # The expected acceptance lengths below only apply to this exact
        # configuration, so reject anything else up front.
        assert args.method in ["eagle", "eagle3"]
        assert args.tp == 1
        assert args.num_spec_tokens == 3
        assert args.dataset_name == "hf"
        assert args.dataset_path == "philschmid/mt-bench"
        assert args.num_prompts == 80
        assert args.temp == 0
        assert args.top_p == 1.0
        assert args.top_k == -1
        assert args.enable_chunked_prefill

        # check acceptance length is within 2% of expected value
        rtol = 0.02
        expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
        lower_bound = (1 - rtol) * expected_acceptance_length
        upper_bound = (1 + rtol) * expected_acceptance_length

        assert lower_bound <= acceptance_length <= upper_bound, (
            f"acceptance_length {acceptance_length} is not "
            f"within {rtol * 100}% of {expected_acceptance_length}"
        )

        print(
            f"Test passed! Expected AL: "
            f"{expected_acceptance_length}, got {acceptance_length}"
        )