test_latency.py 1.82 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import random
import time

import requests

# Default server port for each supported backend, used when --port is omitted.
DEFAULT_PORTS = {
    "srt": 30000,
    "vllm": 21000,
    "lightllm": 22000,
}


def _build_payload(backend, prompt, max_new_tokens):
    """Return the JSON body for a greedy ``/generate`` request to *backend*.

    Each backend names the prompt and the token-cap fields differently; every
    request built here uses ``temperature=0``.

    Args:
        backend: One of the keys of ``DEFAULT_PORTS``.
        prompt: Prompt text to send.
        max_new_tokens: Maximum number of new tokens to request.

    Returns:
        A dict suitable for ``requests.post(..., json=...)``.

    Raises:
        ValueError: If *backend* is not supported.
    """
    if backend == "srt":
        return {
            "text": prompt,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": max_new_tokens,
            },
        }
    if backend == "lightllm":
        return {
            "inputs": prompt,
            "parameters": {
                "temperature": 0,
                "max_new_tokens": max_new_tokens,
            },
        }
    if backend == "vllm":
        return {
            "prompt": prompt,
            "temperature": 0,
            "max_tokens": max_new_tokens,
        }
    raise ValueError(f"Invalid backend: {backend}")


def main():
    """Send one generation request to a backend and report its latency."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--backend", type=str, default="srt")
    parser.add_argument("--max-new-tokens", type=int, default=256)
    args = parser.parse_args()

    # Validate the backend regardless of --port: otherwise an unknown backend
    # with an explicit port would skip every request branch and fail later.
    if args.backend not in DEFAULT_PORTS:
        raise ValueError(f"Invalid backend: {args.backend}")
    if args.port is None:
        args.port = DEFAULT_PORTS[args.backend]

    url = f"{args.host}:{args.port}"
    # Random numeric prompt; presumably chosen so repeated runs do not reuse a
    # cached prefix on the server -- NOTE(review): confirm intent.
    a = random.randint(0, 1 << 20)
    max_new_tokens = args.max_new_tokens
    payload = _build_payload(args.backend, f"{a}, ", max_new_tokens)

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-request.
    tic = time.perf_counter()
    response = requests.post(url + "/generate", json=payload)
    latency = time.perf_counter() - tic

    # Fail loudly on an HTTP error instead of reporting a meaningless latency.
    response.raise_for_status()
    ret = response.json()
    print(ret)

    # NOTE(review): assumes the server generated all max_new_tokens tokens; if
    # generation stops early (e.g. at an end-of-sequence token) this figure
    # overstates throughput.
    speed = max_new_tokens / latency
    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")


if __name__ == "__main__":
    main()