"""
Usage:
export OPENAI_API_KEY=sk-******
python3 openai_chat_speculative.py

***Note: for speculative execution to work, the user must put all "gen" calls in a
single "assistant" turn. Show in "assistant" the desired answer format, and give each
"gen" term a stop token. Stream mode is not supported in speculative execution.

E.g.
correct:
    sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))

incorrect:
    s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
    s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
    s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
"""

import sglang as sgl
from sglang import OpenAI, function, set_default_backend


@function(num_api_spec_tokens=256)
def gen_character_spec(s):
    """Generate a character (Name/Birthday/Job) with speculative execution.

    A few-shot example fixes the answer format so the speculated completion
    is likely to match the template below.
    """
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("Construct a character within the following format:")
    s += sgl.assistant(
        "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
    )
    s += sgl.user("Please generate new Name, Birthday and Job.\n")
    # All "gen" terms live in one assistant turn (required for speculative
    # execution) and each one carries a stop token.
    answer_template = (
        "Name:"
        + sgl.gen("name", stop="\n")
        + "\nBirthday:"
        + sgl.gen("birthday", stop="\n")
        + "\nJob:"
        + sgl.gen("job", stop="\n")
    )
    s += sgl.assistant(answer_template)


@function(num_api_spec_tokens=256)
def gen_character_spec_no_few_shot(s):
    """Speculative character generation WITHOUT a few-shot format example.

    Because the model only gets a loose textual instruction, the speculated
    answer may not match the template, yielding incomplete/odd field values.
    """
    s += sgl.user("Construct a character. For each field stop with a newline\n")
    # Single assistant turn holding every "gen", each with its stop token.
    answer_template = (
        "Name:"
        + sgl.gen("name", stop="\n")
        + "\nAge:"
        + sgl.gen("age", stop="\n")
        + "\nJob:"
        + sgl.gen("job", stop="\n")
    )
    s += sgl.assistant(answer_template)


@function
def gen_character_normal(s):
    """Plain (non-speculative) chat turn: one question, one free-form answer."""
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("What's the answer of 23 + 8?")
    reply = sgl.gen("answer", max_tokens=64)
    s += sgl.assistant(reply)


@function(num_api_spec_tokens=1024)
def multi_turn_question(s, question_1, question_2):
    """Answer two user questions in a single speculative assistant turn.

    A one-shot Q/A example pins the "Answer 1/Answer 2" format so that the
    speculated completion can be parsed back into the two gen slots.
    """
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("Answer questions in the following format:")
    s += sgl.user(
        "Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n"
    )
    s += sgl.assistant(
        "Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n"
    )
    s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2)
    # Both answers are generated inside one assistant turn, each with a stop
    # token, as speculative execution requires.
    answer_template = (
        "Answer 1: "
        + sgl.gen("answer_1", stop="\n")
        + "\nAnswer 2: "
        + sgl.gen("answer_2", stop="\n")
    )
    s += sgl.assistant(answer_template)



def test_spec_single_turn():
    """Run the few-shot speculative program and print messages, fields, and
    token usage.

    NOTE(review): relies on the module-level ``backend`` assigned in the
    ``__main__`` block.
    """
    backend.token_usage.reset()

    state = gen_character_spec.run()
    for message in state.messages():
        print(message["role"], ":", message["content"])

    print("\n-- name:", state["name"])
    print("-- birthday:", state["birthday"])
    print("-- job:", state["job"])
    print(backend.token_usage)


def test_inaccurate_spec_single_turn():
    """Run the no-few-shot speculative program; answers may be low quality."""
    state = gen_character_spec_no_few_shot.run()
    for message in state.messages():
        print(message["role"], ":", message["content"])

    for field in ("name", "age", "job"):
        print("\n-- " + field + ":", state[field])


def test_normal_single_turn():
    """Run the ordinary (non-speculative) program and print the transcript."""
    state = gen_character_normal.run()
    for message in state.messages():
        print(message["role"], ":", message["content"])


def test_spec_multi_turn():
    """Run the two-question speculative program and print both answers."""
    first = "What is the capital of the United States?"
    second = "List two local attractions in the capital of the United States."
    state = multi_turn_question.run(question_1=first, question_2=second)

    for message in state.messages():
        print(message["role"], ":", message["content"])

    print("\n-- answer_1 --\n", state["answer_1"])
    print("\n-- answer_2 --\n", state["answer_2"])


def test_spec_multi_turn_stream():
    """Attempt streaming with a speculative program.

    Streaming is unsupported under speculative execution, so this is expected
    to raise in the stream executor.
    """
    first = "What is the capital of the United States?"
    second = "List two local attractions."
    state = multi_turn_question.run(
        question_1=first,
        question_2=second,
        stream=True,
    )

    for chunk in state.text_iter():
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    backend = OpenAI("gpt-4-turbo")
    set_default_backend(backend)

    # Each entry: (banner title, test callable). Run in order.
    test_plan = [
        # expect reasonable answer for each field
        ("test spec single turn", test_spec_single_turn),
        # expect incomplete or unreasonable answers
        ("test inaccurate spec single turn", test_inaccurate_spec_single_turn),
        # expect reasonable answer
        ("test normal single turn", test_normal_single_turn),
        # expect answer with same format as in the few shot
        ("test spec multi turn", test_spec_multi_turn),
        # expect error in stream_executor: stream is not supported...
        ("test spec multi turn stream", test_spec_multi_turn_stream),
    ]
    for title, run_test in test_plan:
        print("\n========== " + title + " ==========\n")
        run_test()