test_bind_cache.py 1.47 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
3
import unittest

import sglang as sgl
Yineng Zhang's avatar
Yineng Zhang committed
4
from sglang.test.test_utils import MODEL_NAME_FOR_TEST
Lianmin Zheng's avatar
Lianmin Zheng committed
5
6
7
8
9


class TestBind(unittest.TestCase):
    """Smoke tests for ``bind``, ``trace`` and ``cache`` on an ``sgl.function``.

    A single ``sgl.Runtime`` is created for the whole class and installed as
    the default backend. The tests make no assertions; each one passes as
    long as the calls it makes do not raise.
    """

    # Runtime shared by every test; assigned in setUpClass.
    backend = None

    @classmethod
    def setUpClass(cls):
        """Start a runtime for MODEL_NAME_FOR_TEST and make it the default backend."""
        runtime = sgl.Runtime(model_path=MODEL_NAME_FOR_TEST)
        cls.backend = runtime
        sgl.set_default_backend(runtime)

    @classmethod
    def tearDownClass(cls):
        """Shut down the runtime that setUpClass stored on the class."""
        cls.backend.shutdown()

    @staticmethod
    def _new_few_shot_qa():
        """Return a freshly decorated few-shot QA program.

        Both tests use the same program text, so it is defined once here. A
        new decorated function is built on every call so the tests do not
        share one object.
        """

        @sgl.function
        def few_shot_qa(s, prompt, question):
            s += prompt
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        return few_shot_qa

    def test_bind(self):
        # Bind only `prompt`, leave `question` unbound, then trace the bound
        # program and print the graph reachable from the tracer's last node.
        bound_qa = self._new_few_shot_qa().bind(
            prompt="The following are questions with answers.\n\n"
        )

        trace_result = bound_qa.trace()
        graph_text = trace_result.last_node.print_graph_dfs()
        print(graph_text + "\n")

    def test_cache(self):
        # Bind `prompt` and call cache() on the bound program with the
        # class-wide runtime.
        bound_qa = self._new_few_shot_qa().bind(
            prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
        )
        bound_qa.cache(self.backend)
Lianmin Zheng's avatar
Lianmin Zheng committed
48
49
50
51
52
53


if __name__ == "__main__":
    # Run every test in this file. warnings="ignore" is passed through to
    # unittest as the warning filter used while the tests run.
    #
    # To run a single test, name it on the command line instead of editing
    # this file, e.g.:
    #     python <this_file> TestBind.test_cache
    unittest.main(warnings="ignore")