# bench_other.py — json_jump_forward benchmark driver for the outlines, guidance and lmql backends.
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import guidance
from tqdm import tqdm

from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl

# The character JSON schema is spelled out as a hand-written regex because the
# regex generated from a pydantic model (build_regex_from_object(HarryPoterRole))
# ran into FSM bugs. The pattern pins the exact layout: 4-space indentation,
# this key order, one field per line.
# NOTE(review): "length" also accepts a trailing dot with no fraction digits
# (e.g. "11."), which is not a valid JSON number.
character_regex = (
    r'\{\n'
    r'    "name": "[\w\d\s]{1,16}",\n'
    r'    "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n'
    r'    "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n'
    r'    "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n'
    r'    "wand": \{\n'
    r'        "wood": "[\w\d\s]{1,16}",\n'
    r'        "core": "[\w\d\s]{1,16}",\n'
    r'        "length": [0-9]{1,2}\.[0-9]{0,2}\n'
    r'    \},\n'
    r'    "alive": "(Alive|Deceased)",\n'
    r'    "patronus": "[\w\d\s]{1,16}",\n'
    r'    "bogart": "[\w\d\s]{1,16}"\n'
    r'\}'
)

# Hand-written regex for the city JSON: 2-space indentation, fixed key order,
# one field per line, exactly three landmarks.
# NOTE(review): "latitude" can match an empty value (every part is optional),
# and a leading "+" on either number is not valid JSON.
city_regex = (
    r'\{\n'
    r'  "name": "[\w\d\s]{1,16}",\n'
    r'  "country": "[\w\d\s]{1,16}",\n'
    r'  "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n'
    r'  "population": [-+]?[0-9]{1,9},\n'
    r'  "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n'
    r'\}'
)

# fmt: off
def character_gen(name, generate):
    """Build the Harry Potter character prompt for ``name`` and append the completion.

    Args:
        name: character name placed at the start of the prompt.
        generate: completion callable, invoked as
            ``generate(prompt, max_tokens=256, regex=character_regex)``; its
            return value is appended to the prompt.

    Returns:
        The prompt text followed by the generated text.
    """
    s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    s += generate(s, max_tokens=256, regex=character_regex)
    return s
# fmt: on

# fmt: off
def city_gen(document, generate):
    """Prompt for city facts from ``document`` and append the regex-constrained answer.

    Args:
        document: page text embedded between "Page begin." and "Page end.".
        generate: completion callable, invoked as
            ``generate(prompt, max_tokens=256, regex=city_regex)``.

    Returns:
        The prompt text followed by the generated text.
    """
    # NOTE: ``document`` is joined directly to "Page end.", so unless it ends
    # with a newline the two share a line.
    prompt = "".join([
        "Please extract the information of a city from the following wikipedia page.\n",
        "Page begin.\n" + document + "Page end.\n",
        "Here is the name, country, and symbol of the city in JSON format.\n",
    ])
    return prompt + generate(prompt, max_tokens=256, regex=city_regex)
# fmt: on


@guidance
def character_maker(lm, name):
    """Guidance program that fills in the Harry Potter character JSON for ``name``.

    Appends a prompt plus a JSON template to ``lm`` and returns the result.
    Free-text fields are generated with ``guidance.gen`` under
    ``regex_str_no_quote``. Enumerated fields use ``guidance.select``. The wand
    length uses ``regex_float``.
    """
    regex_str_no_quote = r"[\w\d\s]+"
    regex_float = r"[0-9]+\.[0-9]+"
    # NOTE(review): these constraints differ from the module-level
    # character_regex used by the other backends. Strings use "+" rather than
    # "{1,16}", and "length" needs at least one fraction digit but any number
    # of integer digits.
    # Every prompt line carries the 4-space source indentation; the leading
    # backslash only drops the newline right after the opening quotes.
    lm += f"""\
    {name} is a character in Harry Potter. Please fill in the following information about this character.
    {{
        "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
        "house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
        "blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
        "occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
        "wand": {{
            "wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
            "core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
            "length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
        }},
        "alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
        "patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
        "bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
    }}
    """

    return lm


async def call_generate_lmql(
    prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
):
    """Continue ``prompt`` with an LMQL query whose answer must match ``regex``.

    Args:
        prompt: text the query continues (bound to ``question``).
        temperature: forwarded to the query call as ``temperature``.
        max_tokens: the answer is constrained to fewer than this many tokens
            (``len(TOKENS(ANSWER)) < max_tokens``).
        regex: regular expression the answer must match.
        max_len: forwarded to the query call as ``max_len``.
        model: LMQL model object bound into the query; required.
        **kwargs: extra keyword arguments forwarded to the query call.

    Returns:
        Whatever the LMQL query call returns; the query body returns ``ANSWER``.

    Raises:
        ValueError: if ``model`` is None.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if model is None:
        raise ValueError("call_generate_lmql requires a `model` argument")
    import lmql

    # The query is built per call because ``@lmql.query`` binds ``model`` when
    # it decorates the function.
    @lmql.query(model=model)
    async def program(question, max_tokens, regex):
        '''lmql
        """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
        return ANSWER
        '''

    return await program(
        question=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        max_len=max_len,
        regex=regex,
        **kwargs,
    )


@guidance
def city_maker(lm, document):
    """Guidance program that extracts city facts from ``document`` as JSON.

    Appends to ``lm`` a prompt that embeds ``document``, followed by a JSON
    template. The template's values come from ``guidance.gen`` calls named
    "name", "country", "latitude", "population", "landmark1", "landmark2" and
    "landmark3". Returns the extended ``lm``.
    """
    regex_str_no_quote = r"[\w\d\s]+"
    regex_float = r"[0-9]+\.[0-9]+"
    # NOTE(review): regex_float has no sign, so negative latitudes cannot be
    # generated, whereas the module-level city_regex accepts "[-+]".
    # NOTE(review): the prompt asks for "name, country, and symbol", but the
    # template also requests latitude, population and landmarks. city_gen uses
    # the same wording, so it is kept for comparability; confirm it is intended.
    # Unlike city_gen, every prompt line here keeps 4 spaces of source
    # indentation, and the document sits on its own line.
    lm += f"""\
    Please extract the information of a city from the following wikipedia page.
    Page begin.
    {document}
    Page end.
    Here is the name, country, and symbol of the city in JSON format.
    {{
        "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
        "country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
        "latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
        "population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
        "top 3 landmarks": [
            "{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
        ]
    }}
    """

    return lm


def bench_character(args):
    """Run the "character" JSON benchmark against a non-SGLang backend.

    Reads one character name per line from ``args.data_path`` and keeps the
    first ``args.num_jsons``. It then fills in the Harry Potter character JSON
    for each name with the backend named by ``args.backend`` ("outlines",
    "guidance" or "lmql") and times the generation phase.

    Returns:
        ``(states, latency)``. ``states[i]`` is the backend-specific result for
        the i-th name: prompt plus completion for outlines, the guidance ``lm``
        state, or the LMQL answer. ``latency`` is the elapsed
        ``time.perf_counter`` time in seconds.

    Raises:
        ValueError: if ``args.backend`` is not a supported backend.
    """
    with open(args.data_path, "r", encoding="utf-8") as f:
        arguments = [{"name": line.strip()} for line in f]
    arguments = arguments[: args.num_jsons]

    states = [None] * len(arguments)

    # Select backend
    if args.backend == "outlines":
        call_generate = partial(get_call_generate(args), temperature=0)

        def get_one_answer(i):
            states[i] = character_gen(**arguments[i], generate=call_generate)

    elif args.backend == "guidance":
        model = guidance.models.LlamaCpp(
            args.model_path,
            n_gpu_layers=-1,
            n_ctx=args.n_ctx,
        )

        def get_one_answer(i):
            states[i] = model + character_maker(**arguments[i])

    elif args.backend == "lmql":
        import asyncio

        import lmql

        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
        call_generate = partial(
            call_generate_lmql,
            model=model,
            max_tokens=256,
            regex=character_regex,
        )

        # NOTE(review): unlike the other backends, the LMQL prompt is the bare
        # name. It omits the "is a character in Harry Potter..." instruction
        # used by character_gen/character_maker, and states[i] holds only the
        # answer. Kept as-is to preserve the benchmark workload; confirm intent.
        async def get_one_answer_async(i):
            states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tic = time.perf_counter()

    if args.backend != "lmql":
        indices = range(len(arguments))
        if args.parallel == 1:
            for i in tqdm(indices):
                get_one_answer(i)
        else:
            with ThreadPoolExecutor(args.parallel) as executor:
                # Draining the map iterator waits for every task and re-raises
                # an exception raised by a worker.
                list(
                    tqdm(
                        executor.map(get_one_answer, indices),
                        total=len(arguments),
                    )
                )
    else:
        # Run up to ``args.parallel`` requests concurrently per batch.
        batches = [
            range(start, min(start + args.parallel, len(arguments)))
            for start in range(0, len(arguments), args.parallel)
        ]
        # Create the loop explicitly. Calling asyncio.get_event_loop() when no
        # loop is set is deprecated in recent Python versions.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            for batch in tqdm(batches):
                loop.run_until_complete(
                    asyncio.gather(*[get_one_answer_async(i) for i in batch])
                )
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    latency = time.perf_counter() - tic

    return states, latency


def bench_city_doc(args):
    """Run the "city" JSON-extraction benchmark against a non-SGLang backend.

    Reads records from ``args.data_path`` with ``read_jsonl`` and keeps the
    first ``args.num_jsons``. It then extracts the city JSON from each
    record's ``"document"`` field with the backend named by ``args.backend``
    ("outlines" or "guidance") and times the generation phase.

    Returns:
        ``(states, latency)``. ``states[i]`` is the backend-specific result for
        the i-th document: prompt plus completion for outlines, or the guidance
        ``lm`` state. ``latency`` is the elapsed ``time.perf_counter`` time in
        seconds.

    Raises:
        ValueError: if ``args.backend`` is not "outlines" or "guidance".
    """
    arguments = [
        {"document": record["document"]} for record in read_jsonl(args.data_path)
    ]
    arguments = arguments[: args.num_jsons]

    states = [None] * len(arguments)

    # Select backend
    if args.backend == "outlines":
        call_generate = partial(get_call_generate(args), temperature=0)

        def get_one_answer(i):
            states[i] = city_gen(**arguments[i], generate=call_generate)

    elif args.backend == "guidance":
        model = guidance.models.LlamaCpp(
            args.model_path,
            n_gpu_layers=-1,
            n_ctx=args.n_ctx,
        )

        def get_one_answer(i):
            states[i] = model + city_maker(**arguments[i])

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tic = time.perf_counter()

    indices = range(len(arguments))
    if args.parallel == 1:
        for i in tqdm(indices):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            # Draining the map iterator waits for every task and re-raises an
            # exception raised by a worker. The progress bar matches
            # bench_character.
            list(
                tqdm(
                    executor.map(get_one_answer, indices),
                    total=len(arguments),
                )
            )

    latency = time.perf_counter() - tic

    return states, latency


def main(args):
    """Run the benchmark selected by ``args.mode`` and record its latency.

    Prints the latency and dumps the generated states to
    ``tmp_output_{backend}_{mode}.txt``. It also appends a one-line JSON
    summary to ``args.result_file``.

    Raises:
        ValueError: if ``args.mode`` is neither "character" nor "city".
    """
    if args.mode == "character":
        default_data_path, bench = "dataset.txt", bench_character
    elif args.mode == "city":
        default_data_path, bench = "questions.jsonl", bench_city_doc
    else:
        raise ValueError(f"Invalid mode: {args.mode}")

    # Honour an explicit --data-path; fall back to the per-mode file otherwise.
    if getattr(args, "data_path", None) is None:
        args.data_path = default_data_path
    states, latency = bench(args)

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)

    with open(args.result_file, "a", encoding="utf-8") as fout:
        value = {
            "task": "json_jump_forward",
            "backend": args.backend,
            "latency": round(latency, 3),
            "num_jsons": args.num_jsons,
            "mode": args.mode,
            "parallel": args.parallel,
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str)
    parser.add_argument("--num-jsons", type=int, default=50)
    parser.add_argument(
        "--mode", type=str, default="character", choices=["character", "city"]
    )
    # Presumably adds the shared options read elsewhere in this file (backend,
    # parallel, host/port, model path, n_ctx, result file) and parses argv.
    args = add_common_other_args_and_parse(parser)
    main(args)