# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from itertools import cycle, islice

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.inputs import TokensPrompt
from vllm.v1.metrics.reader import Counter, Vector

try:
    from vllm.utils.argparse_utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser


QUESTION = "What is the content of each image?"
IMAGE_URLS = [
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build single-turn chat conversations that each ask QUESTION about one image.

    Images are taken from IMAGE_URLS in order, wrapping around to the start
    when more prompts than images are requested.

    Args:
        num_prompts: Number of conversations to return. Values <= 0 yield an
            empty list (a negative count previously sliced from the end).

    Returns:
        A list of length ``max(num_prompts, 0)``. Each element is a one-message
        conversation ``[{"role": "user", "content": [image_part, text_part]}]``
        suitable for ``LLM.chat``.
    """
    contents = [
        [
            {"type": "image_url", "image_url": {"url": url}},
            {"type": "text", "text": QUESTION},
        ]
        for url in IMAGE_URLS
    ]
    # Cycle through the images so any number of prompts can be produced.
    selected = islice(cycle(contents), max(num_prompts, 0))
    return [[{"role": "user", "content": content}] for content in selected]


def parse_args():
    """Parse command-line options for the speculative-decoding example.

    Dataset-selection options (e.g. dataset name/path and number of prompts)
    are registered by ``add_dataset_parser``. The remaining flags select the
    speculative method, the target/draft models, engine settings, sampling
    settings, output printing and the optional ``--test`` regression check.

    Returns:
        The parsed ``argparse`` namespace.
    """
    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)
    parser.add_argument("--test", action="store_true")
    parser.add_argument(
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp"],
    )
    # Speculative-decoding settings; the prompt-lookup bounds are only used
    # by the "ngram" method.
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
    parser.add_argument("--prompt-lookup-min", type=int, default=2)
    # Engine settings.
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--enforce-eager", action="store_true")
    parser.add_argument("--enable-chunked-prefill", action="store_true")
    parser.add_argument("--max-model-len", type=int, default=16384)
    # Sampling and output settings.
    parser.add_argument("--temp", type=float, default=0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=-1)
    parser.add_argument("--print-output", action="store_true")
    parser.add_argument("--output-len", type=int, default=256)
    # Model locations; None means main() falls back to its built-in defaults.
    parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--eagle-dir", type=str, default=None)
    parser.add_argument("--custom-mm-prompts", action="store_true")
    return parser.parse_args()


def _build_speculative_config(args):
    """Translate the parsed CLI options into a ``speculative_config`` dict for ``LLM``.

    Raises:
        ValueError: if ``args.method`` is not one of the supported methods.
    """
    if args.method in ("eagle", "eagle3"):
        eagle_dir = args.eagle_dir
        if eagle_dir is None:
            # Fall back to the default draft checkpoint for the chosen EAGLE variant.
            if args.method == "eagle":
                eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
            else:
                eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
        return {
            "method": args.method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
        }
    if args.method == "ngram":
        return {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }
    if args.method == "mtp":
        return {
            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    raise ValueError(f"unknown method: {args.method}")


def _aggregate_spec_decode_metrics(metrics, num_spec_tokens):
    """Sum the speculative-decoding metrics returned by ``llm.get_metrics()``.

    Args:
        metrics: Iterable of metric objects exposing ``name`` and either
            ``value`` (Counter) or ``values`` (Vector).
        num_spec_tokens: Length of the per-position acceptance vector.

    Returns:
        ``(num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts)``
        where ``acceptance_counts[i]`` sums the
        ``vllm:spec_decode_num_accepted_tokens_per_pos`` entries at position ``i``.
    """
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    acceptance_counts = [0] * num_spec_tokens
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos, count in enumerate(metric.values):
                acceptance_counts[pos] += count
    return num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts


def main(args):
    """Generate with speculative decoding and report draft-acceptance statistics.

    Prompts come either from the benchmark dataset options (tokenized here and
    passed as token ids) or, with ``--custom-mm-prompts``, from the built-in
    image chat prompts. Generated text is printed with ``--print-output``. The
    speculative-decoding metrics are always printed.

    Returns:
        The mean acceptance length ``1 + accepted_tokens / drafts``, or 1 when
        no drafts were recorded.

    Raises:
        ValueError: if ``--custom-mm-prompts`` is used without ``--model-dir``,
            or if ``args.method`` is not a supported method.
    """
    # NOTE(review): presumably read by get_samples() to pick the prompt format.
    args.endpoint_type = "openai-chat"

    model_dir = args.model_dir
    if model_dir is None:
        if args.custom_mm_prompts:
            raise ValueError(
                "custom_mm_prompts requires mm based models; "
                "default llama3.1-8b-instruct is not mm based; "
                "please specify model_dir to give a mm based model"
            )
        model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    args.custom_skip_chat_template = True

    if not args.custom_mm_prompts:
        prompts = get_samples(args, tokenizer)
        # add_special_tokens is False to avoid adding bos twice
        # when using chat templates
        prompt_ids = [
            tokenizer.encode(prompt.prompt, add_special_tokens=False)
            for prompt in prompts
        ]
    else:
        prompts = get_custom_mm_prompts(args.num_prompts)

    speculative_config = _build_speculative_config(args)

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        enforce_eager=args.enforce_eager,
        gpu_memory_utilization=0.9,
        speculative_config=speculative_config,
        # NOTE(review): stats kept enabled; presumably required for the
        # spec-decode metrics read via llm.get_metrics() below.
        disable_log_stats=False,
        max_model_len=args.max_model_len,
        limit_mm_per_prompt={"image": 5},
        disable_chunked_mm_input=True,
    )

    # Forward every sampling flag defined in parse_args().
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.output_len,
    )
    if not args.custom_mm_prompts:
        outputs = llm.generate(
            [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids],
            sampling_params=sampling_params,
        )
    else:
        outputs = llm.chat(prompts, sampling_params=sampling_params)

    # print the generated text
    if args.print_output:
        for prompt, output in zip(prompts, outputs):
            print("-" * 50)
            if not args.custom_mm_prompts:
                print(f"prompt: {prompt.prompt}")
            else:
                print(f"prompt: {prompt}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    metrics = llm.get_metrics()

    total_num_output_tokens = sum(
        len(output.outputs[0].token_ids) for output in outputs
    )
    num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts = (
        _aggregate_spec_decode_metrics(metrics, args.num_spec_tokens)
    )

    print("-" * 50)
    print(f"total_num_output_tokens: {total_num_output_tokens}")
    print(f"num_drafts: {num_drafts}")
    print(f"num_draft_tokens: {num_draft_tokens}")
    print(f"num_accepted_tokens: {num_accepted_tokens}")
    acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
    print(f"mean acceptance length: {acceptance_length:.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for pos, accepted in enumerate(acceptance_counts):
        acceptance_rate = accepted / num_drafts if num_drafts > 0 else 0
        print(f"acceptance at token {pos}: {acceptance_rate:.2f}")

    return acceptance_length

if __name__ == "__main__":
    args = parse_args()
    acceptance_length = main(args)

    if args.test:
        # Regression check: only the reference configuration below has a known
        # expected acceptance length.
        # takes ~30s to run on 1xH100
        assert args.method in ["eagle", "eagle3"]
        assert args.tp == 1
        assert args.num_spec_tokens == 3
        assert args.dataset_name == "hf"
        assert args.dataset_path == "philschmid/mt-bench"
        assert args.num_prompts == 80
        assert args.temp == 0
        assert args.top_p == 1.0
        assert args.top_k == -1
        assert args.enable_chunked_prefill

        # check acceptance length is within 2% of expected value
        rtol = 0.02
        expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811

        assert (
            (1 - rtol) * expected_acceptance_length
            <= acceptance_length
            <= (1 + rtol) * expected_acceptance_length
        ), (
            f"acceptance_length {acceptance_length} is not "
            f"within {rtol * 100}% of {expected_acceptance_length}"
        )

        print(
            f"Test passed! Expected AL: "
            f"{expected_acceptance_length}, got {acceptance_length}"
        )