"""
Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache

python3 test_httpserver_classify.py
"""

import argparse

import numpy as np
import requests


def get_logits_deprecated(url: str, prompt: str):
    """Fetch the normalized prompt logprob for a single prompt via /generate.

    Deprecated: the /classify endpoint (see get_logits) is the preferred way
    to obtain classification outputs.

    Args:
        url: Base server URL, e.g. "http://127.0.0.1:30000".
        prompt: The prompt string to score.

    Returns:
        The "normalized_prompt_logprob" value from the response meta info.
    """
    # max_new_tokens=0 makes the server score the prompt without generating.
    payload = {
        "text": prompt,
        "sampling_params": {
            "max_new_tokens": 0,
        },
        "return_logprob": True,
    }
    resp = requests.post(url + "/generate", json=payload)
    return resp.json()["meta_info"]["normalized_prompt_logprob"]


def get_logits_batch_deprecated(url: str, prompts: list[str]):
    """Fetch normalized prompt logprobs for a batch of prompts via /generate.

    Deprecated: the /classify endpoint (see get_logits_batch) is the
    preferred way to obtain classification outputs.

    Args:
        url: Base server URL, e.g. "http://127.0.0.1:30000".
        prompts: List of prompt strings to score.

    Returns:
        np.ndarray with one normalized prompt logprob per input prompt.
    """
    # max_new_tokens=0 makes the server score the prompts without generating.
    response = requests.post(
        url + "/generate",
        json={
            "text": prompts,
            "sampling_params": {
                "max_new_tokens": 0,
            },
            "return_logprob": True,
        },
    )
    ret = response.json()
    # The batched /generate endpoint returns one result dict per prompt;
    # iterate the response directly instead of indexing by range(len(prompts)),
    # and build the list with a comprehension rather than list(generator).
    logits = np.array(
        [item["meta_info"]["normalized_prompt_logprob"] for item in ret]
    )
    return logits


def get_logits(url: str, prompt: str):
    """Classify a single prompt via the /classify endpoint.

    Args:
        url: Base server URL, e.g. "http://127.0.0.1:30000".
        prompt: The prompt string to classify.

    Returns:
        The "embedding" field of the response (the classification logits).
    """
    resp = requests.post(url + "/classify", json={"text": prompt})
    return resp.json()["embedding"]


def get_logits_batch(url: str, prompts: list[str]):
    """Classify a batch of prompts via the /classify endpoint.

    Args:
        url: Base server URL, e.g. "http://127.0.0.1:30000".
        prompts: List of prompt strings to classify.

    Returns:
        np.ndarray stacking the "embedding" field of each per-prompt result.
    """
    resp = requests.post(url + "/classify", json={"text": prompts})
    results = resp.json()
    return np.array([item["embedding"] for item in results])


if __name__ == "__main__":
    # Connect to a running sglang server (see module docstring for launch cmd).
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}"

    # Single-prompt request.
    single_prompt = "This is a test prompt.<|eot_id|>"
    logits = get_logits(url, single_prompt)
    print(f"{logits=}")

    # Batched request with prompts of varying lengths.
    batch_prompts = [
        "This is a test prompt.<|eot_id|>",
        "This is another test prompt.<|eot_id|>",
        "This is a long long long long test prompt.<|eot_id|>",
    ]
    logits = get_logits_batch(url, batch_prompts)
    print(f"{logits=}")