"""
Usage:
python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
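
Other backends can be selected with --backend {srt, vllm, lightllm, ginfer}, e.g.:
python3 bench_one.py --backend vllm --input-len 1024 --batch-size 1 8 32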
"""

import argparse
import json
import time

import numpy as np
import requests


def run_one_batch_size(bs):
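    """Send one batch of size `bs` to the selected backend, measure end-to-end
    latency, print the response, and append a summary record to results.jsonl."""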
    url = f"{args.host}:{args.port}"
    max_new_tokens = args.max_tokens

    # When --input-len is set, build `bs` random token-id prompts (ids drawn uniformly
    # from [0, 16384)); otherwise fall back to short, distinct text prompts.
    if args.input_len:
        input_ids = [
            [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
            for _ in range(bs)
        ]
    else:
        text = [f"{i, }" for i in range(bs)]

    tic = time.time()
    if args.backend == "srt":
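        # srt (SGLang runtime) native /generate endpoint; it accepts either token ids
        # ("input_ids") or raw text ("text") together with a sampling_params dict.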
        if args.input_len:
            inputs = {"input_ids": input_ids}
        else:
            inputs = {"text": text}

        response = requests.post(
            url + "/generate",
            json={
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
                **inputs,
            },
        )
    elif args.backend == "lightllm":
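        # Note: this branch sends only the first text prompt and assumes --input-len
        # is not set (otherwise `text` is undefined above).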
        response = requests.post(
            url + "/generate",
            json={
                "inputs": text[0],
                "parameters": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
    elif args.backend == "vllm":
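        # OpenAI-compatible completions endpoint. The prompt is either a list of
        # token-id lists (--input-len) or a list of text prompts; `ignore_eos` is
        # passed through as a vLLM-specific extension to the OpenAI schema.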
        if args.input_len:
            inputs = {"prompt": input_ids}
        else:
            inputs = {"prompt": text}

        response = requests.post(
            url + "/v1/completions",
            json={
                "model": args.vllm_model_name,
                "temperature": 0,
                "max_tokens": max_new_tokens,
                "ignore_eos": True,
                **inputs,
            },
        )
    elif args.backend == "ginfer":
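        # gRPC streaming backend; the timer is restarted below so channel and stub
        # setup are not counted in the measured latency. Like lightllm, this branch
        # uses only the first text prompt.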
        import grpc
        from ginfer import sampler_pb2, sampler_pb2_grpc

        sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
        sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)

        tic = time.time()
        sample_request = sampler_pb2.SampleTextRequest(
            prompt=text[0],
            settings=sampler_pb2.SampleSettings(
                max_len=max_new_tokens,
                rng_seed=0,
                temperature=0,
                nucleus_p=1,
            ),
        )
        stream = sampler.SampleText(sample_request)
        response = "".join([x.text for x in stream])
    latency = time.time() - tic

    if isinstance(response, str):
        ret = response
    else:
        ret = response.json()
    print(ret)

    # Throughput counts only generated output tokens: bs * max_new_tokens / latency.
    output_throughput = bs * max_new_tokens / latency
    print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")

    with open("results.jsonl", "a") as fout:
        res = {
            "backend": args.backend,
            "input_len": args.input_len,
            "output_len": args.max_tokens,
            "batch_size": bs,
            "latency": latency,
            "output_throughput": output_throughput,
        }
        fout.write(json.dumps(res) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--backend", type=str, default="srt")
    parser.add_argument("--input-len", type=int, default=None)
    parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
    parser.add_argument("--max-tokens", type=int, default=256)
    parser.add_argument("--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B")
    args = parser.parse_args()

    if args.port is None:
        if args.backend == "srt":
            args.port = 30000
        elif args.backend == "vllm":
            args.port = 21000
        elif args.backend == "lightllm":
            args.port = 22000
        elif args.backend == "ginfer":
            args.port = 9988
        else:
            raise ValueError(f"Invalid backend: {args.backend}")

    for bs in args.batch_size:
        run_one_batch_size(bs)
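
# A minimal post-processing sketch (not part of the benchmark itself) for summarizing
# the results.jsonl records written above:
#
#   import json
#   with open("results.jsonl") as fin:
#       for line in fin:
#           r = json.loads(line)
#           print(r["backend"], r["batch_size"], f"{r['output_throughput']:.1f} token/s")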