bench_one.py 2.87 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
3
4
5
6
7
8
9
10
import argparse
import time

import requests

if __name__ == "__main__":
    # One-shot latency/throughput benchmark: send a single generation request
    # (optionally batched for the "srt" backend) to one of several LLM-serving
    # backends and report wall-clock latency and decode speed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--backend", type=str, default="srt")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--max-tokens", type=int, default=256)
    args = parser.parse_args()

    # Fill in each backend's conventional default port when --port is omitted.
    if args.port is None:
        if args.backend == "srt":
            args.port = 30000
        elif args.backend == "vllm":
            args.port = 21000
        elif args.backend == "lightllm":
            args.port = 22000
        elif args.backend == "ginfer":
            args.port = 9988
        else:
            raise ValueError(f"Invalid backend: {args.backend}")

    url = f"{args.host}:{args.port}"

    a = 20
    max_new_tokens = args.max_tokens
    # NOTE(review): f"{a, }" formats the one-element tuple (a,), producing the
    # string "(20,)". Looks accidental, but the exact prompt text is irrelevant
    # for a throughput benchmark that runs with ignore_eos=True, so it is
    # preserved as-is.
    prompt = f"{a, }"

    tic = time.time()
    if args.backend == "srt":
        # sglang runtime: accepts a list of prompts, so batching happens here.
        response = requests.post(
            url + "/generate",
            json={
                "text": [prompt] * args.batch_size,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
    elif args.backend == "lightllm":
        response = requests.post(
            url + "/generate",
            json={
                "inputs": prompt,
                "parameters": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
    elif args.backend == "vllm":
        # vllm's api_server takes sampling params at the top level.
        response = requests.post(
            url + "/generate",
            json={
                "prompt": prompt,
                "temperature": 0,
                "max_tokens": max_new_tokens,
                "ignore_eos": True,
            },
        )
    elif args.backend == "ginfer":
        # gRPC backend; imported lazily so the HTTP backends work without grpc.
        import grpc
        from ginfer import sampler_pb2, sampler_pb2_grpc

        sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
        sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)

        # Restart the clock so channel/stub setup is excluded from the timing.
        tic = time.time()
        sample_request = sampler_pb2.SampleTextRequest(
            prompt=prompt,
            settings=sampler_pb2.SampleSettings(
                max_len=max_new_tokens,
                rng_seed=0,
                temperature=0,
                nucleus_p=1,
            ),
        )
        stream = sampler.SampleText(sample_request)
        response = "".join([x.text for x in stream])
    else:
        # Previously, an unknown backend combined with an explicit --port fell
        # through all branches and crashed below with NameError on `response`.
        raise ValueError(f"Invalid backend: {args.backend}")
    latency = time.time() - tic

    # ginfer yields a plain string; the HTTP backends return a Response object.
    if isinstance(response, str):
        ret = response
    else:
        ret = response.json()
    print(ret)

    # tokens/s across the whole batch, assuming every request decodes the full
    # max_new_tokens (guaranteed by ignore_eos=True for srt/lightllm).
    speed = args.batch_size * max_new_tokens / latency
    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")