import json
import random
import unittest

import requests

import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    is_in_ci,
    popen_launch_server,
)


###############################################################################
# Engine Mode Tests (Single-configuration)
###############################################################################
class TestEngineUpdateWeightsFromDisk(unittest.TestCase):
    def setUp(self):
        self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        # Initialize the engine in offline (direct) mode.
        self.engine = sgl.Engine(model_path=self.model)

    def tearDown(self):
        self.engine.shutdown()

    def run_decode(self):
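        # Greedy decoding (temperature=0), so the output changes only when the weights change.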
        prompts = ["The capital of France is"]
        sampling_params = {"temperature": 0, "max_new_tokens": 32}
        outputs = self.engine.generate(prompts, sampling_params)
        print("=" * 100)
        print(
            f"[Engine Mode] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
        )
        return outputs[0]["text"]

    def run_update_weights(self, model_path):
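        # Reload model weights from the given path via Engine.update_weights_from_disk.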
        ret = self.engine.update_weights_from_disk(model_path)
        print(json.dumps(ret))
        return ret

    def test_update_weights(self):
        origin_response = self.run_decode()
        # Update weights: switch to the base model by dropping the "-Instruct" suffix.
        new_model_path = self.model.replace("-Instruct", "")
        ret = self.run_update_weights(new_model_path)
        self.assertTrue(ret[0])  # ret is a tuple; index 0 holds the success flag

        updated_response = self.run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        # Revert back to original weights
        ret = self.run_update_weights(self.model)
        self.assertTrue(ret[0])
        reverted_response = self.run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])

    def test_update_weights_unexist_model(self):
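        # Updating to a non-existent path should fail and leave the current weights untouched.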
        origin_response = self.run_decode()
        new_model_path = self.model.replace("-Instruct", "wrong")
        ret = self.run_update_weights(new_model_path)
        self.assertFalse(ret[0])
        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])


###############################################################################
# HTTP Server Mode Tests (Single-configuration)
###############################################################################
class TestServerUpdateWeightsFromDisk(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self):
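        # Issue a deterministic (temperature=0) /generate request against the running server.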
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {"temperature": 0, "max_new_tokens": 32},
            },
        )
        print("=" * 100)
        print(f"[Server Mode] Generated text: {response.json()['text']}")
        return response.json()["text"]

    def get_model_info(self):
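        # /get_model_info reports the model path the server currently has loaded.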
        response = requests.get(self.base_url + "/get_model_info")
        model_path = response.json()["model_path"]
        print(json.dumps(response.json()))
        return model_path

    def run_update_weights(self, model_path):
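        # Ask the server to reload weights from disk; the response JSON carries a "success" flag.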
        response = requests.post(
            self.base_url + "/update_weights_from_disk",
            json={"model_path": model_path},
        )
        ret = response.json()
        print(json.dumps(ret))
        return ret

    def test_update_weights(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = self.run_update_weights(new_model_path)
        self.assertTrue(ret["success"])

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        ret = self.run_update_weights(origin_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = self.get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])

    def test_update_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
        ret = self.run_update_weights(new_model_path)
        self.assertFalse(ret["success"])

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])


###############################################################################
# Parameterized Tests for update_weights_from_disk
# Test coverage depends on the value of is_in_ci():
# - In a CI environment: randomly select one mode (Engine or Server) and test only with tp=1, dp=1.
# - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations
#   with tp and dp ranging from 1 to 2.
###############################################################################
class TestUpdateWeightsFromDiskParameterized(unittest.TestCase):
    def run_common_test(self, mode, tp, dp):
        """
        Common test procedure for update_weights_from_disk.
        For Engine mode, we instantiate the engine with tp_size=tp.
        For Server mode, we launch the server with --tp-size set to tp
        (dp is not forwarded to the server launch here).
        """
        if mode == "Engine":
            # Instantiate engine with additional parameter tp_size.
            print(f"[Parameterized Engine] Testing with tp={tp}, dp={dp}")
            engine = sgl.Engine(
                model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                random_seed=42,
                tp_size=tp,
                # dp parameter is not explicitly used in this API.
            )
            try:
                self._engine_update_weights_test(engine)
            finally:
                engine.shutdown()
        elif mode == "Server":
            print(f"[Parameterized Server] Testing with tp={tp}, dp={dp}")
            # Pass additional arguments to launch the server.
            base_args = ["--tp-size", str(tp)]
            process = popen_launch_server(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                DEFAULT_URL_FOR_TEST,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=base_args,
            )
            try:
                self._server_update_weights_test(DEFAULT_URL_FOR_TEST)
            finally:
                kill_process_tree(process.pid)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _engine_update_weights_test(self, engine):
        # Run the update weights test on the given engine instance.
        def run_decode():
            prompts = ["The capital of France is"]
            sampling_params = {"temperature": 0, "max_new_tokens": 32}
            outputs = engine.generate(prompts, sampling_params)
            print("=" * 100)
            print(
                f"[Parameterized Engine] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
            )
            return outputs[0]["text"]

        def run_update_weights(model_path):
            ret = engine.update_weights_from_disk(model_path)
            print(json.dumps(ret))
            return ret

        origin_response = run_decode()
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = run_update_weights(new_model_path)
        self.assertTrue(ret[0])
        updated_response = run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])
        ret = run_update_weights(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        self.assertTrue(ret[0])
        reverted_response = run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])
        return origin_response

    def _server_update_weights_test(self, base_url):
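        # Mirrors TestServerUpdateWeightsFromDisk.test_update_weights against an already-launched server.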
        def run_decode():
            response = requests.post(
                base_url + "/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {"temperature": 0, "max_new_tokens": 32},
                },
            )
            print("=" * 100)
            print(f"[Parameterized Server] Generated text: {response.json()['text']}")
            return response.json()["text"]

        def get_model_info():
            response = requests.get(base_url + "/get_model_info")
            model_path = response.json()["model_path"]
            print(json.dumps(response.json()))
            return model_path

        def run_update_weights(model_path):
            response = requests.post(
                base_url + "/update_weights_from_disk",
                json={"model_path": model_path},
            )
            ret = response.json()
            print(json.dumps(ret))
            return ret

        origin_model_path = get_model_info()
        origin_response = run_decode()
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = run_update_weights(new_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = get_model_info()
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)
        updated_response = run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])
        ret = run_update_weights(origin_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)
        reverted_response = run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])
        return origin_response

    def test_parameterized_update_weights(self):
        if is_in_ci():
            # In CI, choose one random mode (Engine or Server) with tp=1, dp=1.
            mode = random.choice(["Engine", "Server"])
            test_suits = [(1, 1, mode)]
        else:
            # Otherwise, test both modes and enumerate tp,dp combinations from 1 to 2.
            test_suits = []
            for mode in ["Engine", "Server"]:
                for tp in [1, 2]:
                    for dp in [1, 2]:
                        test_suits.append((tp, dp, mode))
        for tp, dp, mode in test_suits:
            with self.subTest(mode=mode, tp=tp, dp=dp):
                self.run_common_test(mode, tp, dp)


if __name__ == "__main__":
    unittest.main()