bench_sglang.py 4.32 KB
Newer Older
Liangsheng Yin's avatar
Liangsheng Yin committed
1
2
3
4
5
6
7
8
9
import argparse
import json
import time

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
10
from sglang.utils import dump_state_text, read_jsonl
Liangsheng Yin's avatar
Liangsheng Yin committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

# Hand-written regex describing the character JSON object.
# Converting the pydantic model to a regex (build_regex_from_object) currently
# runs into FSM bugs, so the pattern is spelled out as a plain string instead.
# Each fragment below is one line of the expected output; fragments are glued
# together with the regex escape "\n" (backslash + n), which matches a newline.
character_regex = r"\n".join(
    (
        r'\{',
        r'    "name": "[\w\d\s]{1,16}",',
        r'    "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",',
        r'    "blood status": "(Pure-blood|Half-blood|Muggle-born)",',
        r'    "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",',
        r'    "wand": \{',
        r'        "wood": "[\w\d\s]{1,16}",',
        r'        "core": "[\w\d\s]{1,16}",',
        r'        "length": [0-9]{1,2}\.[0-9]{0,2}',
        r'    \},',
        r'    "alive": "(Alive|Deceased)",',
        r'    "patronus": "[\w\d\s]{1,16}",',
        r'    "bogart": "[\w\d\s]{1,16}"',
        r'\}',
    )
)

32
33
34
35
36
37
38
39
40
41
# Regex describing the city JSON object, one fragment per output line.
# Fragments are joined with the regex escape "\n" (backslash + n), which
# matches a newline.
# NOTE(review): every part of the latitude sub-pattern is optional, so it can
# match an empty string; confirm whether that is intended.
city_regex = r"\n".join(
    (
        r'\{',
        r'  "name": "[\w\d\s]{1,16}",',
        r'  "country": "[\w\d\s]{1,16}",',
        r'  "latitude": [-+]?[0-9]*\.?[0-9]{0,2},',
        r'  "population": [-+]?[0-9]{1,9},',
        r'  "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]',
        r'\}',
    )
)

Liangsheng Yin's avatar
Liangsheng Yin committed
42
43
44
# fmt: off
@sgl.function
def character_gen(s, name):
    """Prompt for a Harry Potter character profile constrained by ``character_regex``.

    Args:
        s: The sglang program state the prompt and generation are appended to.
        name: Character name placed at the start of the prompt.

    The generated text is stored under the ``"json_output"`` key.
    """
    prompt = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    s += prompt
    s += sgl.gen("json_output", regex=character_regex, max_tokens=256)
# fmt: on

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# fmt: off
@sgl.function
def city_gen(s, document):
    """Prompt for city facts from a document, constrained by ``city_regex``.

    Args:
        s: The sglang program state the prompt and generation are appended to.
        document: Page text embedded between the "Page begin." / "Page end." markers.

    The generated text is stored under the ``"json_output"`` key.
    """
    instruction = "Please extract the information of a city from the following wikipedia page.\n"
    page = "Page begin.\n" + document + "Page end.\n"
    # NOTE(review): the prompt mentions "symbol", while city_regex constrains the
    # output to name/country/latitude/population/landmarks; confirm the wording.
    lead_in = "Here is the name, country, and symbol of the city in JSON format.\n"

    s += instruction
    s += page
    s += lead_in
    s += sgl.gen("json_output", regex=city_regex, max_tokens=256)
# fmt: on


def bench_city_doc(args):
    """Run the city-extraction benchmark and time the batch.

    Args:
        args: Parsed arguments. Reads ``data_path`` (a JSONL file whose records
            have a ``"document"`` field), ``num_jsons`` (maximum number of
            requests) and ``parallel`` (thread count); the whole namespace is
            also handed to ``select_sglang_backend``.

    Returns:
        A ``(states, latency)`` tuple: the result of ``city_gen.run_batch`` and
        the wall-clock seconds spent inside that call.
    """
    requests = [{"document": row["document"]} for row in read_jsonl(args.data_path)]
    batch = requests[: args.num_jsons]

    # Every sgl call below goes through the default backend.
    sgl.set_default_backend(select_sglang_backend(args))

    # Only the batch execution itself is timed.
    started = time.time()
    states = city_gen.run_batch(
        batch,
        temperature=0,
        num_threads=args.parallel,
        # The progress bar is only shown for single-threaded runs.
        progress_bar=(args.parallel == 1),
    )
    elapsed = time.time() - started

    return states, elapsed

Liangsheng Yin's avatar
Liangsheng Yin committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106

def bench_character(args):
    """Run the character-profile benchmark and time the batch.

    Args:
        args: Parsed arguments. Reads ``data_path`` (a text file with one
            character name per line), ``num_jsons`` (maximum number of
            requests) and ``parallel`` (thread count); the whole namespace is
            also handed to ``select_sglang_backend``.

    Returns:
        A ``(states, latency)`` tuple: the result of
        ``character_gen.run_batch`` and the wall-clock seconds spent inside
        that call.
    """
    with open(args.data_path, "r") as name_file:
        # One request per line; surrounding whitespace/newline is stripped.
        requests = [{"name": raw_line.strip()} for raw_line in name_file]
    batch = requests[: args.num_jsons]

    # Every sgl call below goes through the default backend.
    sgl.set_default_backend(select_sglang_backend(args))

    # Only the batch execution itself is timed.
    started = time.time()
    states = character_gen.run_batch(
        batch,
        temperature=0,
        num_threads=args.parallel,
        # The progress bar is only shown for single-threaded runs.
        progress_bar=(args.parallel == 1),
    )
    elapsed = time.time() - started

    return states, elapsed


def main(args):
    """Run the benchmark selected by ``args.mode`` and persist its outputs.

    Besides printing the latency, this writes three files:

    * ``tmp_output_{backend}_{mode}.txt`` via ``dump_state_text``;
    * ``{backend}_{mode}.json`` with one generated ``json_output`` per line
      (the file as a whole is newline-separated, not a single JSON document);
    * one JSON summary line appended to ``args.result_file``.

    Args:
        args: Parsed arguments. Reads ``mode``, ``data_path``, ``backend``,
            ``num_jsons``, ``parallel`` and ``result_file``. ``data_path`` is
            filled in with the mode's default file name when it is unset.

    Raises:
        ValueError: If ``args.mode`` is not a supported benchmark mode.
    """
    # mode -> (default dataset file, benchmark runner)
    benchmarks = {
        "character": ("dataset.txt", bench_character),
        "city": ("questions.jsonl", bench_city_doc),
    }
    try:
        default_data_path, run_benchmark = benchmarks[args.mode]
    except KeyError:
        raise ValueError(
            f"Unsupported mode {args.mode!r}; expected one of {sorted(benchmarks)}"
        ) from None

    # Honor an explicit --data-path. It used to be overwritten unconditionally,
    # which made the command-line flag a no-op; the per-mode file name is now
    # only a fallback.
    if not getattr(args, "data_path", None):
        args.data_path = default_data_path

    states, latency = run_benchmark(args)

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
    with open(f"{args.backend}_{args.mode}.json", "w") as fout:
        for state in states:
            fout.write(state["json_output"] + "\n")

    # Append (not overwrite) so results from several runs accumulate.
    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_fast_forward",
            "backend": args.backend,
            "latency": round(latency, 3),
            "num_jsons": args.num_jsons,
            "mode": args.mode,
            "parallel": args.parallel,
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    # Flags specific to this benchmark.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str)
    parser.add_argument("--num-jsons", default=50, type=int)
    parser.add_argument(
        "--mode",
        choices=["character", "city"],
        default="character",
        type=str,
    )

    # main() also reads backend/parallel/result_file, which are not declared
    # here; presumably add_common_sglang_args_and_parse registers them before
    # parsing (confirm in sglang.test.test_utils).
    args = add_common_sglang_args_and_parse(parser)
    main(args)