# bench_sglang.py: SGLang benchmark for regex-constrained JSON generation.
import argparse
import json
import time

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl

# Output schema for character_gen, written out by hand as a regex string.
# A regex derived from a pydantic model (build_regex_from_object) ran into FSM
# bugs, so the pattern is spelled out explicitly instead. Each list element is
# the pattern for one line of the expected JSON; they are joined with the regex
# escape "\n" (backslash + n, not a real newline character).
# NOTE(review): "bogart" is presumably meant to be "boggart"; left unchanged
# because it is part of the text the model is forced to emit.
character_regex = r"\n".join(
    [
        r"\{",
        r'    "name": "[\w\d\s]{1,16}",',
        r'    "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",',
        r'    "blood status": "(Pure-blood|Half-blood|Muggle-born)",',
        r'    "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",',
        r'    "wand": \{',
        r'        "wood": "[\w\d\s]{1,16}",',
        r'        "core": "[\w\d\s]{1,16}",',
        r'        "length": [0-9]{1,2}\.[0-9]{0,2}',
        r"    \},",
        r'    "alive": "(Alive|Deceased)",',
        r'    "patronus": "[\w\d\s]{1,16}",',
        r'    "bogart": "[\w\d\s]{1,16}"',
        r"\}",
    ]
)

# Output schema for city_gen: one regex fragment per line of the expected JSON
# object (2-space indent). Fragments are joined with the regex escape "\n"
# (backslash + n, not a real newline character).
city_regex = r"\n".join(
    [
        r"\{",
        r'  "name": "[\w\d\s]{1,16}",',
        r'  "country": "[\w\d\s]{1,16}",',
        r'  "latitude": [-+]?[0-9]*\.?[0-9]{0,2},',
        r'  "population": [-+]?[0-9]{1,9},',
        r'  "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]',
        r"\}",
    ]
)

# fmt: off
@sgl.function
def character_gen(s, name):
    # SGLang program: prompt with a Harry Potter character name, then generate
    # that character's profile constrained by `character_regex`. The generated
    # text is captured under the "json_output" key of the resulting state.
    s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
# fmt: on

# fmt: off
@sgl.function
def city_gen(s, document):
    # SGLang program: show the model a wikipedia page (`document`) between
    # "Page begin." / "Page end." markers, then generate a city summary
    # constrained by `city_regex`, captured under the "json_output" key.
    # NOTE(review): the prompt asks for "name, country, and symbol", but
    # city_regex requires name, country, latitude, population and top 3
    # landmarks (no "symbol") -- confirm whether this wording is intentional.
    s += "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    s += sgl.gen("json_output",max_tokens=256, regex=city_regex)
# fmt: on


def bench_city_doc(args):
    """Run ``city_gen`` over the documents in ``args.data_path`` and time it.

    Args:
        args: parsed CLI namespace. Reads ``data_path`` (JSONL file whose
            records carry a "document" field), ``num_jsons`` (how many records
            from the start of the file to use), ``parallel`` (passed to
            ``run_batch`` as ``num_threads``), plus whatever
            ``select_sglang_backend`` reads.

    Returns:
        ``(states, latency)``: the states returned by ``city_gen.run_batch``
        and the elapsed seconds for the batch.
    """
    arguments = [
        {"document": record["document"]} for record in read_jsonl(args.data_path)
    ]
    arguments = arguments[: args.num_jsons]

    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    # Run requests. perf_counter is monotonic, so the measured latency is not
    # affected by system clock adjustments the way time.time() can be.
    tic = time.perf_counter()
    states = city_gen.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.perf_counter() - tic

    return states, latency


def bench_character(args):
    """Run ``character_gen`` for each name in ``args.data_path`` and time it.

    Args:
        args: parsed CLI namespace. Reads ``data_path`` (text file, one
            character name per line; surrounding whitespace is stripped),
            ``num_jsons`` (how many names from the start of the file to use),
            ``parallel`` (passed to ``run_batch`` as ``num_threads``), plus
            whatever ``select_sglang_backend`` reads.

    Returns:
        ``(states, latency)``: the states returned by
        ``character_gen.run_batch`` and the elapsed seconds for the batch.
    """
    # Explicit encoding so non-ASCII names don't depend on the platform locale.
    with open(args.data_path, "r", encoding="utf-8") as f:
        arguments = [{"name": line.strip()} for line in f]
    arguments = arguments[: args.num_jsons]

    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    # Run requests. perf_counter is monotonic, so the measured latency is not
    # affected by system clock adjustments the way time.time() can be.
    tic = time.perf_counter()
    states = character_gen.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.perf_counter() - tic

    return states, latency


def main(args):
    """Run the benchmark selected by ``args.mode`` and write its results.

    If no data path was supplied, a per-mode default is used: "dataset.txt"
    for ``character`` and "questions.jsonl" for ``city``.

    Outputs:
      * ``tmp_output_{backend}_{mode}.txt``: state dump via ``dump_state_text``;
      * ``{backend}_{mode}.json``: one generated ``json_output`` per line;
      * one JSON summary record appended to ``args.result_file``.

    Raises:
        ValueError: if ``args.mode`` is neither "character" nor "city".
    """
    # Only fall back to the per-mode default when --data-path was not given,
    # so an explicit --data-path on the command line is honoured.
    data_path = getattr(args, "data_path", None)
    if args.mode == "character":
        args.data_path = data_path if data_path is not None else "dataset.txt"
        states, latency = bench_character(args)
    elif args.mode == "city":
        args.data_path = data_path if data_path is not None else "questions.jsonl"
        states, latency = bench_city_doc(args)
    else:
        # argparse `choices` guards the CLI, but main() can be called directly.
        raise ValueError(f"Unknown mode: {args.mode!r}")

    # Report latency (this task computes no accuracy metric)
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
    with open(f"{args.backend}_{args.mode}.json", "w", encoding="utf-8") as fout:
        for state in states:
            fout.write(state["json_output"] + "\n")

    with open(args.result_file, "a", encoding="utf-8") as fout:
        value = {
            "task": "json_jump_forward",
            "backend": args.backend,
            "latency": round(latency, 3),
            "num_jsons": args.num_jsons,
            "mode": args.mode,
            "parallel": args.parallel,
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-path",
        type=str,
        help="Input data file: one character name per line for --mode "
        "character, or JSONL with a 'document' field for --mode city.",
    )
    parser.add_argument(
        "--num-jsons",
        type=int,
        default=50,
        help="Number of inputs, taken from the start of the data file.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="character",
        choices=["character", "city"],
        help="Which generation benchmark to run.",
    )
    # Adds the shared SGLang options (backend, parallelism, result file, ...)
    # and parses the command line.
    args = add_common_sglang_args_and_parse(parser)
    main(args)