# bench_sglang.py
import argparse
import json
import time

import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text


@sgl.function
def json_decode(s, document):
    """Prompt for facts about a city from a wikipedia page, as a JSON-like object.

    Appends the instruction and the page text to ``s``, then emits four quoted
    fields whose values are generated, each generation stopping at the closing
    double quote.
    """
    # (key shown in the prompt, capture name, token budget), in emission order.
    fields = (
        ("name", "name", 8),
        ("country", "country", 8),
        ("air port code", "air port code", 8),
        ("top 3 landmarks", "landmarks", 24),
    )

    s += "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    s += '{\n'
    for key, capture, budget in fields:
        s += '  "' + key + '": "' + sgl.gen(capture, max_tokens=budget, stop='"') + '",\n'
    s += '}\n'


def main(args):
    """Run the long-document JSON decoding benchmark and record its latency.

    Reads records from ``args.data_path``, runs ``json_decode`` in one batch
    over at most ``args.num_questions`` of them, prints the wall-clock latency,
    dumps the resulting states to ``tmp_output_{args.backend}.txt`` and appends
    a one-line JSON summary to ``args.result_file``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``data_path``, ``num_questions``, ``parallel``, ``backend`` and
            ``result_file``, and passes the namespace to
            ``select_sglang_backend``.
    """
    lines = read_jsonl(args.data_path)
    # The slice caps the batch at num_questions and tolerates shorter files.
    arguments = [
        {"document": line["document"]} for line in lines[:args.num_questions]
    ]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Run requests; only the batched call is timed.
    tic = time.time()
    states = json_decode.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    value = {
        "task": "long_json_decode",
        "backend": args.backend,
        "num_gpus": 1,
        "latency": round(latency, 3),
        # Count what was actually sent: the data file may hold fewer
        # records than the requested num_questions.
        "num_requests": len(arguments),
        "other": {
            "num_questions": args.num_questions,
            "parallel": args.parallel,
        },
    }
    # Append so that successive runs accumulate in the same result file.
    with open(args.result_file, "a") as fout:
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    # Benchmark-specific options; the shared sglang options are added (and the
    # command line parsed) by add_common_sglang_args_and_parse.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data-path",
        type=str,
        default="questions.jsonl",
    )
    cli.add_argument(
        "--num-questions",
        type=int,
        default=10,
    )
    args = add_common_sglang_args_and_parse(cli)
    main(args)