"""
Usage:
python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
"""

import argparse
import json
import time

import numpy as np
import requests


def run_one_batch_size(bs):
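    """Send one batch of size `bs` to the selected backend and report latency and throughput."""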
    url = f"{args.host}:{args.port}"
    max_new_tokens = args.max_tokens

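    # Build the batch: either `bs` random token-id sequences of length --input-len,
    # or `bs` short distinct text prompts.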
    if args.input_len:
        input_ids = [
            [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
            for _ in range(bs)
        ]
    else:
        text = [f"{i, }" for i in range(bs)]

    tic = time.time()
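    # "srt" backend: send the whole batch to the runtime's native /generate endpoint.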
    if args.backend == "srt":
        if args.input_len:
            inputs = {"input_ids": input_ids}
        else:
            inputs = {"text": text}

        response = requests.post(
            url + "/generate",
            json={
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
                **inputs,
            },
        )
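    # LightLLM's /generate takes a single text prompt per request, so only the first
    # prompt is sent here (token-id inputs from --input-len are not handled in this branch).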
    elif args.backend == "lightllm":
        response = requests.post(
            url + "/generate",
            json={
                "inputs": text[0],
                "parameters": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
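    # vLLM: use the OpenAI-compatible /v1/completions endpoint, passing whatever was
    # built above (token-id lists or text prompts) as the prompt.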
    elif args.backend == "vllm":
        if args.input_len:
            inputs = {"prompt": input_ids}
        else:
            inputs = {"prompt": text}

        response = requests.post(
            url + "/v1/completions",
            json={
                "model": args.vllm_model_name,
                "temperature": 0,
                "max_tokens": max_new_tokens,
                "ignore_eos": True,
                **inputs,
            },
        )
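    # ginfer: stream the completion over gRPC; the timer is restarted after the
    # channel setup so connection overhead is not counted.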
    elif args.backend == "ginfer":
        import grpc
        from ginfer import sampler_pb2, sampler_pb2_grpc

        sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
        sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)

        tic = time.time()
        sample_request = sampler_pb2.SampleTextRequest(
            prompt=text[0],
            settings=sampler_pb2.SampleSettings(
                max_len=max_new_tokens,
                rng_seed=0,
                temperature=0,
                nucleus_p=1,
            ),
        )
        stream = sampler.SampleText(sample_request)
        response = "".join([x.text for x in stream])
    latency = time.time() - tic

    if isinstance(response, str):
        ret = response
    else:
        ret = response.json()
    print(ret)

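    # Decode throughput counts only generated tokens; overall throughput also
    # includes the prompt tokens.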
    output_throughput = bs * max_new_tokens / latency
    # Use 0 prompt tokens when --input-len is not set (text-prompt mode), to avoid a TypeError.
    overall_throughput = bs * ((args.input_len or 0) + max_new_tokens) / latency
    print(f"latency: {latency:.2f} s")
    print(f"decode throughput: {output_throughput:.2f} token/s")
    print(f"overall throughput: {overall_throughput:.2f} token/s")

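    # Append one JSON record per run so repeated invocations accumulate in results.jsonl.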
    with open("results.jsonl", "a") as fout:
        res = {
            "backend": args.backend,
            "input_len": args.input_len,
            "output_len": args.max_tokens,
            "batch_size": bs,
            "latency": latency,
            "output_throughput": output_throughput,
            "overall_throughput": overall_throughput,
        }
        fout.write(json.dumps(res) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--backend", type=str, default="srt")
    parser.add_argument("--input-len", type=int, default=None)
    parser.add_argument("--batch-size", type=int, nargs="*", default=[1])
    parser.add_argument("--max-tokens", type=int, default=256)
    parser.add_argument(
        "--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B"
    )
    args = parser.parse_args()

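    # Pick a default port per backend when --port is not given.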
    if args.port is None:
        if args.backend == "srt":
            args.port = 30000
        elif args.backend == "vllm":
            args.port = 21000
        elif args.backend == "lightllm":
            args.port = 22000
        elif args.backend == "ginfer":
            args.port = 9988
        else:
            raise ValueError(f"Invalid backend: {args.backend}")

    for bs in args.batch_size:
        run_one_batch_size(bs)