"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python choices_logprob.py
"""

import sglang as sgl


@sgl.function
def tool_use(s, question):
    """Build a prompt that asks which tool is needed to answer ``question``.

    Two pieces are appended to the program state ``s``: a lead-in that
    embeds the question, then a sentence that ends with a generation slot
    named ``"tool"``. The slot is passed ``choices`` listing
    ``"calculator"`` and ``"search engine"``.
    """
    lead_in = "To answer this question: " + question + ", "
    s += lead_in

    # The result is later read back from the state under the key "tool".
    tool_slot = sgl.gen("tool", choices=["calculator", "search engine"])
    s += "I need to use a " + tool_slot


def _print_tool_choice(question, state):
    """Print the question, the selected tool, and the recorded logprobs.

    ``state`` is a finished ``tool_use`` program state. The selected tool is
    read with ``state["tool"]`` and its metadata with
    ``state.get_meta_info("tool")``.
    """
    print("questions:", question)
    print("choice:", state["tool"])
    meta_info = state.get_meta_info("tool")
    # NOTE(review): presumably one "input_token_logprobs" entry per candidate
    # in the order given to `choices` -- confirm against the SGLang version
    # in use.
    print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
    print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
    print("-" * 50)


def main():
    """Run ``tool_use`` on one question, then on a batch, and print results."""
    # Run one case
    question = "What is 5 + 5?"
    state = tool_use.run(question)
    _print_tool_choice(question, state)

    # Run a batch
    questions = [
        "What is 5 + 6?",
        "Who is Michael Jordan?",
    ]
    states = tool_use.run_batch([{"question": q} for q in questions])
    for question, state in zip(questions, states):
        _print_tool_choice(question, state)


if __name__ == "__main__":
    # Point the SGLang frontend at the server started per the usage notes
    # (local, port 30000) before running any program.
    endpoint = sgl.RuntimeEndpoint("http://localhost:30000")
    sgl.set_default_backend(endpoint)
    main()