openai_chat_speculative.py 4.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Usage:
***Note: for speculative execution to work, the user must put all "gen" calls in a single "assistant" call.
Show the desired answer format in "assistant". Each "gen" term should have a stop token.
Stream mode is not supported in speculative execution.

E.g.
correct:
    sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
incorrect:
    s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
    s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
    s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))

export OPENAI_API_KEY=sk-******
python3 openai_chat_speculative.py
"""

import sglang as sgl
from sglang import OpenAI, function, set_default_backend


@function(num_api_spec_tokens=256)
def gen_character_spec(s):
    """Few-shot prompt that asks for a new Name / Birthday / Job record.

    A worked example exchange is appended first.  The final assistant turn
    then carries all three ``gen`` calls, each with a newline stop, which is
    the layout the usage notes at the top of this file require for
    speculative execution.

    Args:
        s: Program state handed in by the ``function`` decorator; every
            chat turn is appended to it with ``+=``.
    """
    # Worked example showing the answer layout the model should imitate.
    example_answer = (
        "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
    )

    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("Construct a character within the following format:")
    s += sgl.assistant(example_answer)
    s += sgl.user("Please generate new Name, Birthday and Job.\n")

    # Grow the reply template one field at a time; everything still ends up
    # inside a single assistant turn.
    reply = "Name:" + sgl.gen("name", stop="\n")
    reply = reply + "\nBirthday:" + sgl.gen("birthday", stop="\n")
    reply = reply + "\nJob:" + sgl.gen("job", stop="\n")
    s += sgl.assistant(reply)


@function(num_api_spec_tokens=256)
def gen_character_spec_no_few_shot(s):
    """Ask for a Name / Age / Job record without any worked example.

    Only a one-line user instruction precedes the assistant template, so
    the model is never shown the expected layout.  The script's entry point
    runs this through ``test_inaccurate_spec_single_turn`` and notes that
    incomplete or unreasonable answers are expected.

    Args:
        s: Program state handed in by the ``function`` decorator; every
            chat turn is appended to it with ``+=``.
    """
    s += sgl.user("Construct a character. For each field stop with a newline\n")

    # All three ``gen`` slots are assembled into one assistant turn.
    reply = "Name:" + sgl.gen("name", stop="\n")
    reply = reply + "\nAge:" + sgl.gen("age", stop="\n")
    reply = reply + "\nJob:" + sgl.gen("job", stop="\n")
    s += sgl.assistant(reply)


@function
def gen_character_normal(s):
    """Plain single-turn question, declared without ``num_api_spec_tokens``.

    The assistant turn holds one ``gen`` slot named ``"answer"`` capped at
    64 tokens.

    Args:
        s: Program state handed in by the ``function`` decorator; every
            chat turn is appended to it with ``+=``.
    """
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("What's the answer of 23 + 8?")

    answer_slot = sgl.gen("answer", max_tokens=64)
    s += sgl.assistant(answer_slot)


@function(num_api_spec_tokens=1024)
def multi_turn_question(s, question_1, question_2):
    """Few-shot prompt that answers two questions in a fixed layout.

    An example question pair and its two-line answer are appended first,
    followed by the caller's questions.  The final assistant turn holds
    both ``gen`` slots (``answer_1`` and ``answer_2``), each stopping at a
    newline.

    Args:
        s: Program state handed in by the ``function`` decorator; every
            chat turn is appended to it with ``+=``.
        question_1: First question, concatenated after ``"Question 1: "``.
        question_2: Second question, concatenated after ``"Question 2: "``.
    """
    # Example exchange demonstrating the expected answer layout.
    example_questions = "Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n"
    example_answers = "Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n"

    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("Answer questions in the following format:")
    s += sgl.user(example_questions)
    s += sgl.assistant(example_answers)

    # Plain ``+`` is kept on purpose: the arguments are joined as given
    # rather than being forced through string formatting.
    asked = "Question 1: " + question_1 + "\nQuestion 2: " + question_2
    s += sgl.user(asked)

    # Both answers are generated inside one assistant turn.
    reply = "Answer 1: " + sgl.gen("answer_1", stop="\n")
    reply = reply + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n")
    s += sgl.assistant(reply)


def test_spec_single_turn():
    """Run ``gen_character_spec`` and print its transcript, fields and usage.

    Reads the module-level ``backend`` created in the ``__main__`` section:
    its token-usage counter is reset before the run and printed afterwards.
    """
    backend.token_usage.reset()

    state = gen_character_spec.run()
    for message in state.messages():
        print(message["role"], ":", message["content"])

    # Label / variable pairs, printed in the same order as the template.
    field_labels = (
        ("\n-- name:", "name"),
        ("-- birthday:", "birthday"),
        ("-- job:", "job"),
    )
    for label, key in field_labels:
        print(label, state[key])
    print(backend.token_usage)


def test_inaccurate_spec_single_turn():
    """Run ``gen_character_spec_no_few_shot`` and print transcript and fields."""
    state = gen_character_spec_no_few_shot.run()
    for message in state.messages():
        print(message["role"], ":", message["content"])

    # Each field label starts with its own blank line.
    for key in ("name", "age", "job"):
        print("\n-- " + key + ":", state[key])


def test_normal_single_turn():
    """Run ``gen_character_normal`` and print every chat message."""
    transcript = gen_character_normal.run().messages()
    for message in transcript:
        print(message["role"], ":", message["content"])


def test_spec_multi_turn():
    """Run ``multi_turn_question`` and print the transcript and both answers."""
    questions = {
        "question_1": "What is the capital of the United States?",
        "question_2": "List two local attractions in the capital of the United States.",
    }
    state = multi_turn_question.run(**questions)

    for message in state.messages():
        print(message["role"], ":", message["content"])

    for key in ("answer_1", "answer_2"):
        print("\n-- " + key + " --\n", state[key])


def test_spec_multi_turn_stream():
    """Run ``multi_turn_question`` with ``stream=True`` and echo the text.

    The usage notes at the top of this file state that stream mode is not
    supported in speculative execution, and the script's entry point expects
    this call to report an error.
    """
    questions = {
        "question_1": "What is the capital of the United States?",
        "question_2": "List two local attractions.",
    }
    state = multi_turn_question.run(stream=True, **questions)

    # Echo each piece as it arrives, without adding newlines.
    pieces = state.text_iter()
    for piece in pieces:
        print(piece, end="", flush=True)


if __name__ == "__main__":
    # ``backend`` is read as a global by test_spec_single_turn and is also
    # registered through set_default_backend.
    backend = OpenAI("gpt-4-turbo")
    set_default_backend(backend)

    # (banner, scenario) pairs, executed strictly in this order.
    scenarios = (
        # expect reasonable answer for each field
        ("\n========== test spec single turn ==========\n", test_spec_single_turn),
        # expect incomplete or unreasonable answers
        (
            "\n========== test inaccurate spec single turn ==========\n",
            test_inaccurate_spec_single_turn,
        ),
        # expect reasonable answer
        ("\n========== test normal single turn ==========\n", test_normal_single_turn),
        # expect answer with same format as in the few shot
        ("\n========== test spec multi turn ==========\n", test_spec_multi_turn),
        # expect error in stream_executor: stream is not supported...
        (
            "\n========== test spec multi turn stream ==========\n",
            test_spec_multi_turn_stream,
        ),
    )

    for banner, scenario in scenarios:
        print(banner)
        scenario()