"""
Usage:
python3 -m unittest test_srt_engine.TestSRTEngine.test_4_sync_async_stream_combination
"""

import asyncio
import json
import unittest
from types import SimpleNamespace

import torch

import sglang as sgl
from sglang.bench_offline_throughput import BenchArgs, throughput_test
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.few_shot_gsm8k_engine import run_eval
from sglang.test.test_utils import (
    DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    CustomTestCase,
)

class TestSRTEngine(CustomTestCase):
    """End-to-end tests for the sglang offline Engine API.

    Covers: Engine vs. Runtime output consistency (text and embeddings),
    text vs. pre-tokenized inputs, all sync/async x stream/non-stream
    generation modes, GSM8K accuracy, CPU weight offload, offline
    throughput, and sync vs. async encode consistency.
    """

    def test_1_engine_runtime_consistency(self):
        """Greedy decoding must produce identical text via Engine and Runtime."""
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        # temperature=0 makes decoding deterministic so outputs are comparable.
        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        engine = sgl.Engine(model_path=model_path, random_seed=42)
        out1 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        # Runtime returns a JSON string rather than a dict.
        runtime = sgl.Runtime(model_path=model_path, random_seed=42)
        out2 = json.loads(runtime.generate(prompt, sampling_params))["text"]
        runtime.shutdown()

        print("==== Answer 1 ====")
        print(out1)

        print("==== Answer 2 ====")
        print(out2)
        self.assertEqual(out1, out2)

    def test_2_engine_runtime_encode_consistency(self):
        """Embeddings from Engine.encode and Runtime.encode must match closely."""
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST

        engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
        out1 = torch.tensor(engine.encode(prompt)["embedding"])
        engine.shutdown()

        runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42)
        out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"])
        runtime.shutdown()

        # Allow small numerical noise between the two serving paths.
        self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))

    def test_3_engine_token_ids_consistency(self):
        """Generating from raw text and from pre-tokenized ids must agree."""
        # just to ensure there is no issue running multiple generate calls
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        # Radix cache disabled so the second call cannot reuse the first
        # call's cached prefix and mask a tokenization mismatch.
        engine = sgl.Engine(
            model_path=model_path, random_seed=42, disable_radix_cache=True
        )
        out1 = engine.generate(prompt, sampling_params)["text"]

        tokenizer = get_tokenizer(model_path)
        token_ids = tokenizer.encode(prompt)
        out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[
            "text"
        ]

        engine.shutdown()

        print("==== Answer 1 ====")
        print(out1)

        print("==== Answer 2 ====")
        print(out2)
        self.assertEqual(out1, out2)

    def test_4_sync_async_stream_combination(self):
        """Smoke-test all four modes: {sync, async} x {non-stream, stream}."""
        prompt = "AI safety is"
        sampling_params = {"temperature": 0.8, "top_p": 0.95}

        # Create an LLM.
        llm = sgl.Engine(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        )

        # 1. sync + non streaming
        print("\n\n==== 1. sync + non streaming ====")
        output = llm.generate(prompt, sampling_params)
        print(output["text"])

        # 2. sync + streaming
        print("\n\n==== 2. sync + streaming ====")
        output_generator = llm.generate(prompt, sampling_params, stream=True)
        offset = 0
        for output in output_generator:
            # Each chunk carries the full text so far; print only the new part.
            print(output["text"][offset:], end="", flush=True)
            offset = len(output["text"])
        print()

        # 3. async + non_streaming
        print("\n\n==== 3. async + non streaming ====")
        # asyncio.run replaces the get_event_loop()/run_until_complete pattern,
        # which is deprecated since Python 3.10 when no loop is running.
        output = asyncio.run(llm.async_generate(prompt, sampling_params))
        print(output["text"])

        # 4. async + streaming
        async def async_streaming(engine):
            generator = await engine.async_generate(
                prompt, sampling_params, stream=True
            )

            offset = 0
            async for output in generator:
                print(output["text"][offset:], end="", flush=True)
                offset = len(output["text"])
            print()

        print("\n\n==== 4. async + streaming ====")
        asyncio.run(async_streaming(llm))

        llm.shutdown()

    def test_5_gsm8k(self):
        """Few-shot GSM8K accuracy via the engine must clear a minimal bar."""
        args = SimpleNamespace(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            local_data_path=None,
            num_shots=5,
            num_questions=1400,
        )

        metrics = run_eval(args)
        self.assertGreater(metrics["accuracy"], 0.33)

    def test_6_engine_cpu_offload(self):
        """CPU weight offload must not change greedy decoding output."""
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        # Baseline: no offload.
        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            max_total_tokens=128,
        )
        out1 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        # Same config plus CPU offload; output must be identical.
        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            max_total_tokens=128,
            cpu_offload_gb=3,
        )
        out2 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        print("==== Answer 1 ====")
        print(out1)

        print("==== Answer 2 ====")
        print(out2)
        self.assertEqual(out1, out2)

    def test_7_engine_offline_throughput(self):
        """Offline throughput benchmark must exceed a minimum total throughput."""
        server_args = ServerArgs(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        )
        bench_args = BenchArgs(num_prompts=10)
        result = throughput_test(server_args=server_args, bench_args=bench_args)
        self.assertGreater(result["total_throughput"], 3000)

    def test_8_engine_async_encode_consistency(self):
        """encode and async_encode must return the same embedding."""
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST

        engine = sgl.Engine(
            model_path=model_path,
            is_embedding=True,
            random_seed=42,
            disable_radix_cache=True,
        )

        # Get sync and async embeddings
        out1 = torch.tensor(engine.encode(prompt)["embedding"])
        # asyncio.run avoids the deprecated get_event_loop() pattern and
        # guarantees a fresh event loop for the async call.
        out2 = torch.tensor(asyncio.run(engine.async_encode(prompt))["embedding"])

        engine.shutdown()

        print("\n==== Shapes ====")
        print(f"sync shape: {out1.shape}")
        print(f"async shape: {out2.shape}")

        self.assertTrue(
            torch.allclose(out1, out2, atol=1e-5, rtol=1e-3),
            "Sync and async embeddings are not equal within tolerance",
        )

# Allow running this file directly: `python3 test_srt_engine.py`.
if __name__ == "__main__":
    unittest.main()