# test_latency.py
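"""Measure the end-to-end latency of a single generation request against one of
several model-serving backends (srt, vllm, lightllm, or ginfer).

Example:
    python3 test_latency.py --backend srt
"""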
import argparse
import time

import requests

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--backend", type=str, default="srt")
    args = parser.parse_args()

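    # Fill in each backend's default port when --port is not given.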
    if args.port is None:
        if args.backend == "srt":
            args.port = 30000
        elif args.backend == "vllm":
            args.port = 21000
        elif args.backend == "lightllm":
            args.port = 22000
        elif args.backend == "ginfer":
            args.port = 9988
        else:
            raise ValueError(f"Invalid backend: {args.backend}")

    url = f"{args.host}:{args.port}"
    a = 20
    max_new_tokens = 256
    prompt = f"{a, }"  # formats the tuple (a,): the prompt is the literal string "(20,)"

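    # Measure wall-clock latency of the full generation request.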
    tic = time.time()
    if args.backend == "srt":
        response = requests.post(
            url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                },
            },
        )
    elif args.backend == "lightllm":
        response = requests.post(
            url + "/generate",
            json={
                "inputs": prompt,
                "parameters": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                },
            },
        )
    elif args.backend == "vllm":
        response = requests.post(
            url + "/generate",
            json={
                "prompt": prompt,
                "temperature": 0,
                "max_tokens": max_new_tokens,
            },
        )
    elif args.backend == "ginfer":
        # Imported here so the gRPC dependencies are only required for this backend.
        import grpc
        from ginfer import sampler_pb2, sampler_pb2_grpc

        sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
        sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)

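        # Reset the timer so channel/stub construction is excluded from the measured latency.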
        tic = time.time()
        sample_request = sampler_pb2.SampleTextRequest(
            prompt=prompt,
            settings=sampler_pb2.SampleSettings(
                max_len=max_new_tokens,
                rng_seed=0,
                temperature=0,
                nucleus_p=1,
            ),
        )
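        # SampleText returns a stream of partial outputs; join them into the full text.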
        stream = sampler.SampleText(sample_request)
        response = "".join([x.text for x in stream])

    latency = time.time() - tic

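    # The ginfer path already returns a string; HTTP backends return a requests.Response.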
    if isinstance(response, str):
        ret = response
    else:
        ret = response.json()
    print(ret)

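    # Decoding speed, assuming the model generated exactly max_new_tokens tokens.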
    speed = max_new_tokens / latency
    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")