# NOTE: Currently this can only be run through HTTP requests.
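# This example queries a locally running SGLang server and inspects the token
# logprobs returned for regex-constrained (JSON) generation, through both the
# native /generate API and the OpenAI-compatible /v1/completions API.
# `character_regex` is imported from the neighboring json_decode example.
# A server is assumed to already be listening at `base_url`; one way to start
# it (assumed command line) is:
#   python -m sglang.launch_server --model-path <model> --port 30000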
import json
from concurrent.futures import ThreadPoolExecutor

from json_decode import character_regex

from sglang.utils import http_request

character_names = ["Hermione Granger", "Ron Weasley", "Harry Potter"]

base_url = "http://localhost:30000"

prompt = "is a character in Harry Potter. Please fill in the following information about this character.\n"


def openai_api_request(name):
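    """Query the OpenAI-compatible /v1/completions endpoint.

    `regex` is an SGLang extension that constrains decoding to character_regex,
    and `logprobs: 3` requests the top-3 alternatives for each generated token.
    The assertions sanity-check that the returned logprob arrays have matching
    lengths and line up with the reported completion token count (minus one).
    """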
    data = {
        "model": "",
        "prompt": name + prompt,
        "temperature": 0,
        "max_tokens": 128,
        "regex": character_regex,
        "logprobs": 3,
    }
    res = http_request(base_url + "/v1/completions", json=data).json()

    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
    #     fout.write(json.dumps(res, indent=4))

    logprobs = res["choices"][0]["logprobs"]
    usage = res["usage"]
    assert len(logprobs["token_logprobs"]) == len(logprobs["tokens"])
    assert len(logprobs["token_logprobs"]) == len(logprobs["top_logprobs"])
    assert len(logprobs["token_logprobs"]) == usage["completion_tokens"] - 1

    return res


def srt_api_request(name):
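    """Query SGLang's native /generate endpoint.

    return_logprob, top_logprobs_num, and return_text_in_logprobs ask the
    server to include per-token logprobs (with decoded token text) for both
    the prompt and the generated output; logprob_start_len=0 requests prompt
    logprobs starting from the first token.
    """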
    data = {
        "text": name + prompt,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 128,
            "regex": character_regex,
        },
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    }

    res = http_request(base_url + "/generate", json=data).json()

    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
    #     fout.write(json.dumps(res, indent=4))

    meta_info = res["meta_info"]
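    # meta_info carries the logprob arrays: input_* for prompt tokens and
    # output_* for generated tokens. The checks below verify that their
    # lengths are consistent with the reported token counts.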
    assert len(meta_info["input_token_logprobs"]) == len(
        meta_info["input_top_logprobs"]
    )
    assert len(meta_info["output_token_logprobs"]) == len(
        meta_info["output_top_logprobs"]
    )
    assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
    assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1

    return res


def pretty_print(res):
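    """Print a column-aligned dump of the native-API logprobs.

    One row per token: the sampled token first, then its top-k alternatives.
    Index [2] of each logprob entry holds the decoded token text (available
    because return_text_in_logprobs=True); it is printed via .encode() so
    that whitespace and special characters stay visible.
    """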
    meta_info = res["meta_info"]

    print("\n\n", "=" * 30, "Prefill", "=" * 30)
    for i in range(len(meta_info["input_token_logprobs"])):
        print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
        top_ks = (
            [str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
            if meta_info["input_top_logprobs"][i]
            else []
        )
        for top_k in top_ks:
            print(f"{top_k: <15}", end="")
        print()

    print("\n\n", "=" * 30, "Decode", "=" * 30)
    for i in range(len(meta_info["output_token_logprobs"])):
        print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
        top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
        for top_k in top_ks:
            print(f"{top_k: <15}", end="")
        print()

    print(res["text"])


if __name__ == "__main__":
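    # Exercise the native API concurrently for all characters, print each
    # result, then run one request against the OpenAI-compatible API.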
    with ThreadPoolExecutor() as executor:
        ress = executor.map(srt_api_request, character_names)

    for res in ress:
        pretty_print(res)

    openai_api_request("Hermione Granger")