# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


def load_prompts(dataset_path, num_prompts):
    """Load up to ``num_prompts`` first-turn prompts from a JSONL dataset.

    Each non-blank line of ``dataset_path`` must be a JSON object whose
    ``"turns"`` list holds the conversation; only the first turn is used.
    If ``dataset_path`` does not exist, two built-in sample prompts are
    used instead.

    Args:
        dataset_path: Path to the JSONL file.
        num_prompts: Maximum number of prompts to return (applied as a
            list slice, so ordinary slice semantics hold).

    Returns:
        A list of prompt strings, or ``[]`` if the file exists but could
        not be read or parsed (the error is printed).
    """
    if not os.path.exists(dataset_path):
        prompts = [
            "The future of AI is", "The president of the United States is"
        ]
        return prompts[:num_prompts]

    prompts = []
    try:
        with open(dataset_path, encoding="utf-8") as f:
            for line in f:
                # Skip blank lines (e.g. an extra trailing newline); json.loads
                # rejects them, which previously discarded the whole dataset.
                if not line.strip():
                    continue
                prompts.append(json.loads(line)["turns"][0])
    except (OSError, ValueError, KeyError, IndexError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # KeyError/IndexError/TypeError cover records without a usable
        # "turns" list.
        print(f"Error reading dataset: {e}")
        return []

    return prompts[:num_prompts]


def parse_args(argv=None):
    """Parse command-line options for the EAGLE speculative-decoding example.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Offline EAGLE / EAGLE-3 speculative decoding example.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="./examples/data/gsm8k.jsonl",
        help="downloaded from the eagle repo "
        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
    )
    parser.add_argument("--max_num_seqs",
                        type=int,
                        default=8,
                        help="maximum number of concurrent sequences")
    parser.add_argument("--num_prompts",
                        type=int,
                        default=80,
                        help="number of prompts to take from the dataset")
    parser.add_argument("--num_spec_tokens",
                        type=int,
                        default=2,
                        help="number of speculative (draft) tokens per step")
    parser.add_argument("--tp",
                        type=int,
                        default=1,
                        help="tensor parallel size of the target model")
    parser.add_argument("--draft_tp",
                        type=int,
                        default=1,
                        help="tensor parallel size of the draft model")
    parser.add_argument("--enforce_eager", action="store_true")
    parser.add_argument("--enable_chunked_prefill", action="store_true")
    parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
    # Float default to match type=float (argparse does not convert
    # non-string defaults).
    parser.add_argument("--temp",
                        type=float,
                        default=0.0,
                        help="sampling temperature")
    return parser.parse_args(argv)


def _report_acceptance(outputs, num_spec_tokens):
    """Print speculative-decoding acceptance statistics aggregated over outputs.

    Per-request counts are read from ``output.metrics.spec_token_acceptance_counts``.
    Outputs lacking ``metrics`` (or whose metrics are ``None``) are skipped.
    When nothing was collected, nothing is printed.
    """
    # +1 accounts for the token from the target model that is always
    # accepted, so index 0 is counted once per forward pass.
    acceptance_counts = [0] * (num_spec_tokens + 1)
    for output in outputs:
        metrics = getattr(output, "metrics", None)
        counts = getattr(metrics, "spec_token_acceptance_counts", None)
        if counts is None:
            continue
        for step, count in enumerate(counts):
            acceptance_counts[step] += count

    num_steps = acceptance_counts[0]
    if num_steps == 0:
        # No metrics were reported; avoid dividing by zero.
        return

    print("-" * 50)
    # Average number of accepted tokens per forward pass.
    print("mean acceptance length: "
          f"{sum(acceptance_counts) / num_steps:.2f}")
    print("-" * 50)

    # Acceptance rate at each token position.
    for i, count in enumerate(acceptance_counts):
        print(f"acceptance at token {i}: {count / num_steps:.2f}")


def main():
    """Run offline EAGLE speculative decoding and report acceptance stats.

    Loads prompts and tokenizes them with the target model's chat template.
    It then generates with an EAGLE / EAGLE-3 draft model attached through
    ``speculative_config``. When per-request metrics are available, it prints
    the mean acceptance length and the per-position acceptance rate.
    """
    args = parse_args()

    model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"

    max_model_len = 2048

    prompts = load_prompts(args.dataset, args.num_prompts)
    if not prompts:
        # load_prompts returns [] on a read/parse error; don't spend time
        # loading the model when there is nothing to generate.
        print("No prompts to run; exiting.")
        return

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    prompt_ids = [
        tokenizer.apply_chat_template(
            [{
                "role": "user",
                "content": prompt
            }],
            add_generation_prompt=True,
        ) for prompt in prompts
    ]

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        max_num_batched_tokens=args.max_num_batched_tokens,
        enforce_eager=args.enforce_eager,
        max_model_len=max_model_len,
        max_num_seqs=args.max_num_seqs,
        gpu_memory_utilization=0.8,
        speculative_config={
            "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
            "draft_tensor_parallel_size": args.draft_tp,
            "max_model_len": max_model_len,
        },
        disable_log_stats=False,
    )

    sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

    # Pre-tokenized inputs are passed as token prompts rather than via the
    # deprecated ``prompt_token_ids=`` keyword of ``LLM.generate``.
    outputs = llm.generate(
        [{"prompt_token_ids": ids} for ids in prompt_ids],
        sampling_params=sampling_params,
    )

    # The previous guard `hasattr(outputs, "metrics")` inspected the list of
    # outputs, not each output, so it always returned early. Metrics are now
    # checked per request inside the helper.
    _report_acceptance(outputs, args.num_spec_tokens)


if __name__ == "__main__":
    # Start the benchmark only when this file is executed as a script;
    # importing it just defines the helpers.
    main()