"""
Usage:
python3 -m unittest test_srt_engine.TestSRTEngine.test_3_sync_streaming_combination
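
Run the full suite:
python3 -m unittest test_srt_engine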
"""

import asyncio
import json
import unittest
from types import SimpleNamespace

import torch

import sglang as sgl
from sglang.bench_offline_throughput import BenchArgs, throughput_test
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server_args import ServerArgs
from sglang.test.few_shot_gsm8k_engine import run_eval
from sglang.test.test_utils import (
    DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
)


class TestSRTEngine(unittest.TestCase):

    def test_1_engine_runtime_consistency(self):
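        # The Engine and Runtime entry points should produce identical text for
        # the same prompt under greedy decoding (temperature 0) with a fixed seed.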
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        engine = sgl.Engine(model_path=model_path, random_seed=42)
        out1 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        runtime = sgl.Runtime(model_path=model_path, random_seed=42)
        out2 = json.loads(runtime.generate(prompt, sampling_params))["text"]
        runtime.shutdown()

        print("==== Answer 1 ====")
        print(out1)

        print("==== Answer 2 ====")
        print(out2)
        self.assertEqual(out1, out2)

    def test_2_engine_multiple_generate(self):
        # Ensure that back-to-back generate calls on the same engine complete
        # without issue.
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        engine = sgl.Engine(model_path=model_path, random_seed=42)
        engine.generate(prompt, sampling_params)
        engine.generate(prompt, sampling_params)
        engine.shutdown()

    def test_3_sync_streaming_combination(self):
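        # Exercise all four call styles on one engine:
        # {sync, async} x {non-streaming, streaming}.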

        prompt = "AI safety is..."
        sampling_params = {"temperature": 0.8, "top_p": 0.95}

        async def async_streaming(engine):
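            # Consume the async stream, printing each chunk as it arrives.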

            generator = await engine.async_generate(
                prompt, sampling_params, stream=True
            )

            async for output in generator:
                print(output["text"], end="", flush=True)
            print()

        # Create one engine shared by the four cases below.
        llm = sgl.Engine(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        )

        # 1. sync + non streaming
        print("\n\n==== 1. sync + non streaming ====")
        output = llm.generate(prompt, sampling_params)

        print(output["text"])

        # 2. sync + streaming
        print("\n\n==== 2. sync + streaming ====")
        output_generator = llm.generate(prompt, sampling_params, stream=True)
        for output in output_generator:
            print(output["text"], end="", flush=True)
        print()

        # asyncio.get_event_loop() is deprecated when no loop is running,
        # so create one explicitly for the async cases.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # 3. async + non_streaming
        print("\n\n==== 3. async + non streaming ====")
        output = loop.run_until_complete(llm.async_generate(prompt, sampling_params))
        print(output["text"])

        # 4. async + streaming
        print("\n\n==== 4. async + streaming ====")
        loop.run_until_complete(async_streaming(llm))

        llm.shutdown()

    def test_4_gsm8k(self):
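        # Run the few-shot GSM8K evaluation through the engine and require
        # accuracy above a loose floor.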

        args = SimpleNamespace(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            local_data_path=None,
            num_shots=5,
            num_questions=200,
        )

        metrics = run_eval(args)
        self.assertGreater(metrics["accuracy"], 0.3)

    def test_5_prompt_input_ids_consistency(self):
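        # Generating from pre-tokenized input_ids should match generating from
        # the raw text prompt; the radix cache is disabled so the two calls
        # share no prefix cache.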
        prompt = "The capital of UK is"

        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        engine = sgl.Engine(
            model_path=model_path, random_seed=42, disable_radix_cache=True
        )
        sampling_params = {"temperature": 0, "max_new_tokens": 8}
        out1 = engine.generate(prompt, sampling_params)["text"]

        tokenizer = get_tokenizer(model_path)
        token_ids = tokenizer.encode(prompt)
        out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[
            "text"
        ]

        engine.shutdown()

        print("==== Answer 1 ====")
        print(out1)

        print("==== Answer 2 ====")
        print(out2)
        self.assertEqual(out1, out2)

    def test_6_engine_runtime_encode_consistency(self):
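        # Engine.encode and Runtime.encode should return matching embeddings
        # (within floating-point tolerance).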
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST

        engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
        out1 = torch.tensor(engine.encode(prompt)["embedding"])
        engine.shutdown()

        runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42)
        out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"])
        runtime.shutdown()

        self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))

    def test_7_engine_offline_throughput(self):
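        # Smoke-test the offline throughput benchmark and enforce a minimum
        # total throughput.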
        server_args = ServerArgs(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        )
        bench_args = BenchArgs(num_prompts=10)
        result = throughput_test(server_args=server_args, bench_args=bench_args)
        self.assertGreater(result["total_throughput"], 3500)

    def test_8_engine_cpu_offload(self):
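        # Offloading part of the weights to CPU (cpu_offload_gb) should not
        # change greedy outputs relative to the no-offload baseline.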
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            max_total_tokens=128,
        )
        out1 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            max_total_tokens=128,
            cpu_offload_gb=3,
        )
        out2 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        print("==== Answer 1 ====")
        print(out1)

        print("==== Answer 2 ====")
        print(out2)
        self.assertEqual(out1, out2)


if __name__ == "__main__":
    unittest.main()