test_tracing.py
import unittest

import sglang as sgl
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template


class TestTracing(unittest.TestCase):
    def test_few_shot_qa(self):
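        # Trace a few-shot QA program without running a backend and print the traced node graph (DFS order).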
        @sgl.function
        def few_shot_qa(s, question):
            s += "The following are questions with answers.\n\n"
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        tracer = few_shot_qa.trace()
        print(tracer.last_node.print_graph_dfs() + "\n")

    def test_select(self):
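        # Trace a program that uses sgl.select to choose among fixed options, then print the traced graph.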
        @sgl.function
        def capital(s):
            s += "The capital of France is"
            s += sgl.select("capital", ["Paris. ", "London. "])
            s += "It is a city" + sgl.gen("description", stop=".")

        tracer = capital.trace()
        print(tracer.last_node.print_graph_dfs() + "\n")

    def test_raise_warning(self):
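        # Formatting a traced argument with an f-string (instead of `+` concatenation) is invalid usage;
        # tracing such a program is expected to raise a TypeError.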
        @sgl.function
        def wrong(s, question):
            s += f"I want to ask {question}"

        try:
            wrong.trace()
            raised = False
        except TypeError:
            raised = True

        assert raised

    def test_multi_function(self):
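        # Compile a program that invokes another sgl.function, print its graph,
        # then run it once and as a batch on the OpenAI backend.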
        @sgl.function
        def expand(s, tip):
            s += (
                "Please expand the following tip into a detailed paragraph:"
                + tip
                + "\n"
            )
            s += sgl.gen("detailed_tip")

        @sgl.function
        def tip_suggestion(s, topic):
            s += "Here are 2 tips for " + topic + ".\n"

            s += "1." + sgl.gen("tip_1", stop=["\n", ":", "."]) + "\n"
            s += "2." + sgl.gen("tip_2", stop=["\n", ":", "."]) + "\n"

            branch1 = expand(tip=s["tip_1"])
            branch2 = expand(tip=s["tip_2"])

            s += "Tip 1: " + branch1["detailed_tip"] + "\n"
            s += "Tip 2: " + branch2["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        compiled = tip_suggestion.compile()
        compiled.print_graph()

        sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
        state = compiled.run(topic="staying healthy")
        print(state.text() + "\n")

        states = compiled.run_batch(
            [
                {"topic": "staying healthy"},
                {"topic": "staying happy"},
                {"topic": "earning money"},
            ],
            temperature=0,
        )
        for s in states:
            print(s.text() + "\n")

    def test_role(self):
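        # Compile a multi-turn chat built from role tags, using the llama-2 chat template
        # attached to a bare BaseBackend, and print its graph.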
        @sgl.function
        def multi_turn_chat(s):
            s += sgl.user("Who are you?")
            s += sgl.assistant(sgl.gen("answer_1"))
            s += sgl.user("Who created you?")
            s += sgl.assistant(sgl.gen("answer_2"))

        backend = BaseBackend()
        backend.chat_template = get_chat_template("llama-2-chat")

        compiled = multi_turn_chat.compile(backend=backend)
        compiled.print_graph()

    def test_fork(self):
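        # Trace a program that forks the state into three parallel branches, print the graph,
        # then run it on the OpenAI backend.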
        @sgl.function
        def tip_suggestion(s):
            s += (
                "Here are three tips for staying healthy: "
                "1. Balanced Diet; "
                "2. Regular Exercise; "
                "3. Adequate Sleep\n"
            )

            forks = s.fork(3)
            for i in range(3):
                forks[i] += f"Now, expand tip {i+1} into a paragraph:\n"
                forks[i] += sgl.gen("detailed_tip")

            s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
            s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
            s += "Tip 3:" + forks[2]["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        tracer = tip_suggestion.trace()
        print(tracer.last_node.print_graph_dfs())

        a = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct"))
        print(a.text())


if __name__ == "__main__":
    unittest.main(warnings="ignore")

    # t = TestTracing()
    # t.test_multi_function()