# spec_decode.py
# SPDX-License-Identifier: Apache-2.0

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.v1.metrics.reader import Counter, Vector

try:
    from vllm.utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser


def parse_args():
    """Build the command-line parser for this example and parse ``sys.argv``.

    The dataset-related options are registered by ``add_dataset_parser``;
    the remaining options are declared here.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)

    # Prompt source.
    parser.add_argument(
        "--dataset",
        type=str,
        default="./examples/data/gsm8k.jsonl",
        help="downloaded from the eagle repo "
        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
    )

    # Speculative decoding method and its tuning knobs.
    parser.add_argument(
        "--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3"]
    )
    parser.add_argument("--max-num-seqs", type=int, default=8)
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
    parser.add_argument("--prompt-lookup-min", type=int, default=2)

    # Engine configuration.
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--draft-tp", type=int, default=1)
    parser.add_argument("--enforce-eager", action="store_true")
    parser.add_argument("--enable-chunked-prefill", action="store_true")
    parser.add_argument("--max-num-batched-tokens", type=int, default=2048)

    # Sampling and output.
    parser.add_argument("--temp", type=float, default=0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=-1)
    parser.add_argument("--print-output", action="store_true")
    parser.add_argument("--output-len", type=int, default=256)

    # Model locations; when left as None, main() substitutes its defaults.
    parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--eagle-dir", type=str, default=None)
    parser.add_argument("--max-model-len", type=int, default=2048)
    return parser.parse_args()


def main():
    """Run offline speculative decoding on dataset prompts and report acceptance.

    Steps, in order:
      1. Parse CLI arguments and load the tokenizer for the target model.
      2. Fetch prompts with ``get_samples`` and tokenize them.
      3. Build the ``speculative_config`` dict for the chosen method.
      4. Construct the ``LLM`` engine, generate, and optionally print outputs.
      5. Read engine metrics and print the mean acceptance length and the
         acceptance rate at each speculative token position.

    Raises:
        ValueError: if ``args.method`` is not one of the handled methods.
    """
    args = parse_args()
    # get_samples() receives the whole namespace; the endpoint type is fixed
    # here instead of being exposed as a CLI flag.
    args.endpoint_type = "openai-chat"

    model_dir = args.model_dir
    if model_dir is None:
        model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    prompts = get_samples(args, tokenizer)
    # add_special_tokens is False to avoid adding bos twice when using chat templates
    prompt_ids = [
        tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
    ]

    if args.method in ("eagle", "eagle3"):
        eagle_dir = args.eagle_dir
        if eagle_dir is None:
            # Fall back to the default draft checkpoint matching the method.
            if args.method == "eagle":
                eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
            else:
                eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
        speculative_config = {
            "method": args.method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
            "draft_tensor_parallel_size": args.draft_tp,
            "max_model_len": args.max_model_len,
        }
    elif args.method == "ngram":
        speculative_config = {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
            "max_model_len": args.max_model_len,
        }
    else:
        raise ValueError(f"unknown method: {args.method}")

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        max_num_batched_tokens=args.max_num_batched_tokens,
        enforce_eager=args.enforce_eager,
        max_model_len=args.max_model_len,
        max_num_seqs=args.max_num_seqs,
        gpu_memory_utilization=0.8,
        speculative_config=speculative_config,
        # Stats logging must stay enabled so the metrics below are collected.
        disable_log_stats=False,
    )

    # --top-p / --top-k are forwarded so the CLI flags actually take effect;
    # their defaults (1.0 and -1) leave sampling unrestricted.
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.output_len,
    )
    outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)

    # print the generated text
    if args.print_output:
        for output in outputs:
            print("-" * 50)
            print(f"prompt: {output.prompt}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    try:
        metrics = llm.get_metrics()
    except AssertionError:
        print("Metrics are not supported in the V0 engine.")
        return

    # Aggregate the speculative-decoding counters across all reported metrics.
    num_drafts = num_accepted = 0
    acceptance_counts = [0] * args.num_spec_tokens
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos, count in enumerate(metric.values):
                acceptance_counts[pos] += count

    if num_drafts == 0:
        # Every statistic below divides by num_drafts; without drafts there is
        # nothing meaningful to report, so bail out instead of crashing.
        print("-" * 50)
        print("no speculative drafts were recorded; acceptance stats unavailable")
        print("-" * 50)
        return

    print("-" * 50)
    print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i, count in enumerate(acceptance_counts):
        print(f"acceptance at token {i}:{count / num_drafts:.2f}")


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()