"tests/plugins_tests/test_stats_logger_plugins.py" did not exist on "c2bba690658823be5ec6f1742eb10294d4fa2479"
spec_decode.py 10.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.v1.metrics.reader import Counter, Vector

try:
    # Preferred: the argument parser class shipped with vLLM.
    from vllm.utils.argparse_utils import FlexibleArgumentParser
except ImportError:
    # That module is not importable here; fall back to the standard-library
    # parser under the same name so the rest of the script is unchanged.
    from argparse import ArgumentParser as FlexibleArgumentParser


# Question text attached to every image in the custom multimodal prompts.
QUESTION = "What is the content of each image?"

# Sample image URLs used by get_custom_mm_prompts(), one image per prompt.
IMAGE_URLS = [
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build ``num_prompts`` single-turn chat conversations about the sample images.

    Each conversation is a one-element list holding a ``user`` message whose
    content is an image URL entry followed by the ``QUESTION`` text entry.
    When more conversations are requested than there are entries in
    ``IMAGE_URLS``, the images are reused in order from the beginning.

    Args:
        num_prompts: Number of conversations to return.

    Returns:
        A list of conversations, each a list with one message dict.
    """
    contents = [
        [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": QUESTION},
        ]
        for image_url in IMAGE_URLS
    ]

    pool_size = len(IMAGE_URLS)
    if num_prompts > pool_size:
        # Repeat the whole pool often enough to cover the request; the slice
        # below trims the surplus.
        contents *= num_prompts // pool_size + 1

    selected = contents[:num_prompts]
    return [[{"role": "user", "content": content}] for content in selected]


def parse_args():
    """Parse the command-line options of this speculative-decoding example.

    Dataset-related options are registered by ``add_dataset_parser``; every
    other flag is declared explicitly below.

    Returns:
        The namespace returned by ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)

    # Run mode, speculative-decoding method and request backend.
    parser.add_argument("--test", action="store_true")
    parser.add_argument(
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
    )
    parser.add_argument("--backend", type=str, default="openai")

    # Speculation settings and engine options.
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
    parser.add_argument("--prompt-lookup-min", type=int, default=2)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--enforce-eager", action="store_true")
    parser.add_argument("--enable-chunked-prefill", action="store_true")
    parser.add_argument("--max-model-len", type=int, default=16384)

    # Sampling and output options.
    parser.add_argument("--temp", type=float, default=0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=-1)
    parser.add_argument("--print-output", action="store_true")
    parser.add_argument("--output-len", type=int, default=256)

    # Model locations; None means main() picks its built-in default.
    parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--eagle-dir", type=str, default=None)
    parser.add_argument("--draft-model", type=str, default=None)

    # Multimodal prompts and remaining engine knobs.
    parser.add_argument("--custom-mm-prompts", action="store_true")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
    parser.add_argument("--disable-padded-drafter-batch", action="store_true")
    parser.add_argument("--max-num-seqs", type=int, default=None)
    parser.add_argument("--parallel-drafting", action="store_true")
    parser.add_argument("--allowed-local-media-path", type=str, default="")

    return parser.parse_args()


def _build_speculative_config(args):
    """Return the ``speculative_config`` dict for the method chosen in ``args``.

    Raises:
        ValueError: If ``args.method`` is not one of the handled methods, or
            if the ``draft_model`` method is chosen without a draft model.
    """
    method = args.method

    if method in ("eagle", "eagle3"):
        eagle_dir = args.eagle_dir
        if eagle_dir is None:
            # Built-in default checkpoint for each method when --eagle-dir
            # is not given.
            eagle_dir = (
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
                if method == "eagle"
                else "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
            )
        return {
            "method": method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
            "disable_padded_drafter_batch": args.disable_padded_drafter_batch,
            "parallel_drafting": args.parallel_drafting,
        }

    if method == "ngram":
        return {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }

    if method == "draft_model":
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if not args.draft_model:
            raise ValueError("--draft-model is required when --method is draft_model")
        return {
            "method": method,
            "model": args.draft_model,
            "num_speculative_tokens": args.num_spec_tokens,
            "enforce_eager": args.enforce_eager,
            "max_model_len": args.max_model_len,
            "parallel_drafting": args.parallel_drafting,
        }

    if method == "mtp":
        return {
            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }

    raise ValueError(f"unknown method: {args.method}")


def _collect_spec_decode_stats(metrics, num_spec_tokens):
    """Sum the speculative-decoding counters found in ``metrics``.

    Returns:
        A tuple ``(num_drafts, num_draft_tokens, num_accepted_tokens,
        acceptance_counts)`` where ``acceptance_counts`` has one slot per
        speculative token position.
    """
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    acceptance_counts = [0] * num_spec_tokens

    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos, accepted in enumerate(metric.values):
                acceptance_counts[pos] += accepted

    return num_drafts, num_draft_tokens, num_accepted_tokens, acceptance_counts


def main(args):
    """Run generation with speculative decoding and print acceptance statistics.

    Args:
        args: Parsed command-line namespace. Besides the flags declared in
            ``parse_args``, this reads ``args.num_prompts`` and, when dataset
            prompts are used, ``args.enable_multimodal_chat`` (the script
            entry point sets the latter before calling ``main``).

    Returns:
        The mean acceptance length, ``1 + accepted_tokens / drafts``, or ``1``
        when no drafts were recorded.

    Raises:
        ValueError: If custom multimodal prompts are requested without
            ``--model-dir``, or the speculative method is invalid.
    """
    model_dir = args.model_dir
    if model_dir is None:
        if args.custom_mm_prompts:
            raise ValueError(
                "custom_mm_prompts requires mm based models; "
                "default llama3.1-8b-instruct is not mm based, "
                "please specify model_dir to give a mm based model"
            )
        model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    if args.custom_mm_prompts:
        prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
    else:
        prompts = get_samples(args, tokenizer)
        if args.enable_multimodal_chat:
            llm_prompts = [p.prompt for p in prompts]
        else:
            # add_special_tokens is False to avoid adding bos twice
            # when using chat templates
            llm_prompts = [
                {
                    "prompt_token_ids": tokenizer.encode(
                        prompt.prompt, add_special_tokens=False
                    ),
                    "multi_modal_data": prompt.multi_modal_data,
                }
                for prompt in prompts
            ]

    speculative_config = _build_speculative_config(args)

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        enforce_eager=args.enforce_eager,
        gpu_memory_utilization=args.gpu_memory_utilization,
        speculative_config=speculative_config,
        # Stats logging must stay enabled: the metrics are read back below.
        disable_log_stats=False,
        max_model_len=args.max_model_len,
        limit_mm_per_prompt={"image": 5},
        disable_chunked_mm_input=True,
        max_num_seqs=args.max_num_seqs,
        allowed_local_media_path=args.allowed_local_media_path,
    )

    # Forward --top-p / --top-k too; previously they were parsed but ignored.
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.output_len,
    )

    # Custom multimodal prompts are chat-style message lists, so they must go
    # through llm.chat regardless of the selected backend.
    if args.backend == "openai-chat" or args.custom_mm_prompts:
        outputs = llm.chat(llm_prompts, sampling_params=sampling_params)
    else:
        outputs = llm.generate(
            llm_prompts,
            sampling_params=sampling_params,
        )

    # print the generated text
    if args.print_output:
        for prompt, output in zip(prompts, outputs):
            print("-" * 50)
            if not args.custom_mm_prompts:
                print(f"prompt: {prompt.prompt}")
            else:
                print(f"prompt: {prompt}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    metrics = llm.get_metrics()

    total_num_output_tokens = sum(
        len(output.outputs[0].token_ids) for output in outputs
    )
    (
        num_drafts,
        num_draft_tokens,
        num_accepted_tokens,
        acceptance_counts,
    ) = _collect_spec_decode_stats(metrics, args.num_spec_tokens)

    print("-" * 50)
    print(f"total_num_output_tokens: {total_num_output_tokens}")
    print(f"num_drafts: {num_drafts}")
    print(f"num_draft_tokens: {num_draft_tokens}")
    print(f"num_accepted_tokens: {num_accepted_tokens}")
    acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
    print(f"mean acceptance length: {acceptance_length:.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i, accepted in enumerate(acceptance_counts):
        acceptance_rate = accepted / num_drafts if num_drafts > 0 else 0
        print(f"acceptance at token {i}: {acceptance_rate:.2f}")

    return acceptance_length

# Script entry point: parse the CLI, run the example, optionally self-check.
if __name__ == "__main__":
    args = parse_args()
    # main() reads this flag when it builds prompts from a dataset.
    args.enable_multimodal_chat = args.backend == "openai-chat"

    acceptance_length = main(args)

    if args.test:
        # takes ~30s to run on 1xH100
        # The expected values below are only checked for this exact setup.
        assert args.method in ["eagle", "eagle3"]
        assert args.tp == 1
        assert args.num_spec_tokens == 3
        assert args.dataset_name == "hf"
        assert args.dataset_path == "philschmid/mt-bench"
        assert args.num_prompts == 80
        assert args.temp == 0
        assert args.top_p == 1.0
        assert args.top_k == -1
        assert args.enable_chunked_prefill

        # check acceptance length is within 2% of expected value
        rtol = 0.02
        expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811

        lower_bound = (1 - rtol) * expected_acceptance_length
        upper_bound = (1 + rtol) * expected_acceptance_length
        assert lower_bound <= acceptance_length <= upper_bound, (
            f"acceptance_length {acceptance_length} is not "
            f"within {rtol * 100}% of {expected_acceptance_length}"
        )

        print(
            f"Test passed! Expected AL: "
            f"{expected_acceptance_length}, got {acceptance_length}"
        )