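"""bench_sglang.py: benchmark SGLang on MMLU multiple-choice questions.

For each subject found under ``<data_dir>/test``, build a few-shot prompt from
``<data_dir>/dev``, ask the selected backend for a one-token answer, and report
accuracy and latency. Results are appended as a JSON line to ``args.result_file``.
"""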
import argparse
import json
import os
import time

import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

# Answer letters, assigned in order to the option columns of each CSV row.
choices = list("ABCD")

# In this file the tokenizer is only used to count prompt tokens when trimming
# few-shot examples; the gpt-3.5-turbo encoding is a proxy and may not match
# the tokenizer of the model actually being served.
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")


def format_subject(subject):
    """Turn an underscore-separated subject id into display text.

    Every ``_``-separated word is prefixed with a single space, so the result
    always starts with a space and can be appended directly after "about" in
    the few-shot header, e.g. ``"high_school_math"`` -> ``" high school math"``.
    """
    # str.join instead of repeated += (which is quadratic in the number of words).
    return "".join(" " + word for word in subject.split("_"))


def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of a question DataFrame as a prompt fragment.

    Column layout, as indexed below: column 0 is the question, the last column
    is the answer, and the ``df.shape[1] - 2`` columns in between are the
    options, labelled with consecutive entries of ``choices``.

    Args:
        df: DataFrame read with ``header=None``.
        idx: Positional row index.
        include_answer: If True, append the answer followed by a blank line
            (used for few-shot examples); otherwise stop right after
            ``"Answer:"`` so the model completes it.

    Returns:
        The formatted question text.
    """
    num_options = df.shape[1] - 2
    # str() so a non-string question cell is formatted rather than crashing on +=.
    parts = [str(df.iloc[idx, 0])]
    parts.extend(
        "\n{}. {}".format(choices[j], df.iloc[idx, j + 1]) for j in range(num_options)
    )
    parts.append("\nAnswer:")
    if include_answer:
        parts.append(" {}\n\n".format(df.iloc[idx, num_options + 1]))
    return "".join(parts)


def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot prefix: a subject header followed by ``k`` solved examples.

    Args:
        train_df: DataFrame whose leading rows become the worked examples.
        subject: Underscore-separated subject id, shown in the header.
        k: Number of leading rows to include. ``-1`` means all rows; values
            larger than the number of rows are clamped to it instead of
            indexing past the end. Other negative values include no examples.

    Returns:
        The prompt prefix; each included example ends with a blank line.
    """
    header = "The following are multiple choice questions (with answers) about{}.\n\n".format(
        format_subject(subject)
    )
    num_rows = train_df.shape[0]
    k = num_rows if k == -1 else min(k, num_rows)
    return header + "".join(format_example(train_df, i) for i in range(k))


def evaluate(args, subject, dev_df, test_df, max_few_shot_tokens=1536):
    """Benchmark one subject: build prompts, query the backend, score answers.

    Args:
        args: Parsed CLI namespace; this function reads ``ntrain``,
            ``backend`` and ``parallel`` (plus whatever
            ``select_sglang_backend`` reads).
        subject: Underscore-separated subject id, used in the prompt header
            and the printed summary.
        dev_df: Rows used as solved few-shot examples.
        test_df: Rows to ask; the last column is compared to the prediction.
        max_few_shot_tokens: Token budget (counted with the module-level
            ``tokenizer``) for the few-shot prefix. Examples are dropped from
            the end until the prefix fits or none remain.

    Returns:
        ``(cors, acc, latency)``: NumPy array of per-question correctness,
        mean accuracy, and wall-clock seconds spent on the batch run.
    """
    k = args.ntrain
    if k == -1:
        # gen_prompt treats -1 as "all rows"; make it explicit so the loop
        # below can count down to zero.
        k = dev_df.shape[0]
    few_shot_examples = gen_prompt(dev_df, subject, k)
    # Stop at k == 0: going negative would reach -1 ("all rows") and re-grow
    # the prefix, or loop forever if the header alone exceeds the budget.
    while k > 0 and len(tokenizer.encode(few_shot_examples)) > max_few_shot_tokens:
        k -= 1
        few_shot_examples = gen_prompt(dev_df, subject, k)

    num_questions = test_df.shape[0]
    label_col = test_df.shape[1] - 1
    prompts = [
        format_example(test_df, i, include_answer=False) for i in range(num_questions)
    ]
    labels = [test_df.iloc[i, label_col] for i in range(num_questions)]
    arguments = [{"question": p} for p in prompts]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    if args.backend.startswith("gpt-"):
        # Backends named "gpt-*" get the prompt wrapped in user/assistant roles.
        @sgl.function
        def few_shot_mmlu(s, examples, question):
            s += sgl.user(examples + question)
            s += sgl.assistant(sgl.gen("answer"))

    else:
        # All other backends receive the prompt as plain text to continue.
        @sgl.function
        def few_shot_mmlu(s, examples, question):
            s += examples + question + sgl.gen("answer")

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Select backend
    backend = select_sglang_backend(args)

    tic = time.time()
    states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
        arguments,
        temperature=0,
        max_new_tokens=1,
        backend=backend,
        num_threads=args.parallel,
    )
    # The prediction is the first non-whitespace character of the completion.
    preds = []
    for state in states:
        answer = state["answer"].strip()
        preds.append(answer[0] if answer else "")
    latency = time.time() - tic

    cors = [pred == label for pred, label in zip(preds, labels)]
    acc = np.mean(cors)
    cors = np.array(cors)

    print(
        "Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
            acc, latency, len(prompts), subject
        )
    )

    return cors, acc, latency


def main(args):
    """Benchmark up to ``args.nsub`` subjects and append the results to a file.

    Subjects are discovered from ``<data_dir>/test/*_test.csv``. Each one is
    paired with the first ``args.ntrain`` rows of
    ``<data_dir>/dev/<subject>_dev.csv`` as few-shot examples. Per-subject
    results from ``evaluate`` are aggregated into a total latency and an
    accuracy pooled over all questions. Both are printed and appended as one
    JSON line to ``args.result_file``.

    Raises:
        ValueError: If no subjects are selected (no ``*_test.csv`` files, or
            ``args.nsub`` is 0).
    """
    suffix = "_test.csv"
    test_dir = os.path.join(args.data_dir, "test")
    # endswith (not a substring test) so stray files such as "x_test.csv.bak"
    # do not produce duplicate subjects.
    subjects = sorted(f[: -len(suffix)] for f in os.listdir(test_dir) if f.endswith(suffix))
    selected = subjects[: args.nsub]
    if not selected:
        # np.concatenate below would otherwise fail with an opaque message.
        raise ValueError(
            "no subjects to evaluate in {} (nsub={})".format(test_dir, args.nsub)
        )

    all_cors = []
    all_latencies = []
    num_requests = 0

    for subject in tqdm(selected):
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: args.ntrain]
        test_df = pd.read_csv(os.path.join(test_dir, subject + suffix), header=None)

        cors, _, latency = evaluate(args, subject, dev_df, test_df)
        all_cors.append(cors)
        all_latencies.append(latency)
        num_requests += len(test_df)

    total_latency = np.sum(all_latencies)
    print("Total latency: {:.3f}".format(total_latency))

    # Pool every question, so subjects with more questions weigh more.
    weighted_acc = np.mean(np.concatenate(all_cors))
    print("Average accuracy: {:.3f}".format(weighted_acc))

    # Write results (append mode: repeated runs accumulate JSON lines).
    with open(args.result_file, "a", encoding="utf-8") as fout:
        value = {
            "task": "mmlu",
            "backend": args.backend,
            "num_gpus": 1,  # hard-coded; not derived from args
            "latency": round(total_latency, 3),
            "accuracy": round(weighted_acc, 3),
            "num_requests": num_requests,
            "other": {
                "nsub": args.nsub,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    # Benchmark-specific flags, registered in this order. Shared flags read
    # elsewhere (backend, parallel, result_file) are presumably added by
    # add_common_sglang_args_and_parse, which also parses the command line.
    # Note: --save_dir is not read anywhere in this file as shown.
    cli = argparse.ArgumentParser()
    for flags, options in (
        (("--ntrain", "-k"), dict(type=int, default=5)),
        (("--data_dir", "-d"), dict(type=str, default="data")),
        (("--save_dir", "-s"), dict(type=str, default="results")),
        (("--nsub",), dict(type=int, default=60)),
    ):
        cli.add_argument(*flags, **options)
    main(add_common_sglang_args_and_parse(cli))