import json
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests

import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)

###############################################################################
# Engine Mode Tests (Single-configuration)
###############################################################################
class TestEngineUpdateWeightsFromDisk(CustomTestCase):
    """Exercise update_weights_from_disk through the offline Engine API."""

    def setUp(self):
        # Every test gets its own in-process engine so weight swaps in one
        # test cannot leak into another.
        self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        self.engine = sgl.Engine(model_path=self.model)

    def tearDown(self):
        self.engine.shutdown()

    def run_decode(self):
        """Greedy-decode a fixed prompt and return the generated text."""
        prompt_list = ["The capital of France is"]
        params = {"temperature": 0, "max_new_tokens": 32}
        results = self.engine.generate(prompt_list, params)
        print("=" * 100)
        print(
            f"[Engine Mode] Prompt: {prompt_list[0]}\nGenerated text: {results[0]['text']}"
        )
        return results[0]["text"]

    def run_update_weights(self, model_path):
        """Reload engine weights from *model_path* and return the raw result."""
        result = self.engine.update_weights_from_disk(model_path)
        print(json.dumps(result))
        return result

    def test_update_weights(self):
        baseline = self.run_decode()

        # Swap to the base (non-Instruct) checkpoint; greedy output should change.
        base_model = self.model.replace("-Instruct", "")
        result = self.run_update_weights(base_model)
        self.assertTrue(result[0])  # result is a tuple; index 0 holds the success flag

        changed = self.run_decode()
        self.assertNotEqual(baseline[:32], changed[:32])

        # Restore the original weights; output must match the baseline again.
        result = self.run_update_weights(self.model)
        self.assertTrue(result[0])
        restored = self.run_decode()
        self.assertEqual(baseline[:32], restored[:32])

    def test_update_weights_unexist_model(self):
        baseline = self.run_decode()
        bad_model = self.model.replace("-Instruct", "wrong")
        result = self.run_update_weights(bad_model)
        self.assertFalse(result[0])
        # A failed update must leave the loaded weights untouched.
        unchanged = self.run_decode()
        self.assertEqual(baseline[:32], unchanged[:32])


###############################################################################
# HTTP Server Mode Tests (Single-configuration)
###############################################################################
class TestServerUpdateWeightsFromDisk(CustomTestCase):
    """Exercise /update_weights_from_disk through a launched HTTP server."""

    @classmethod
    def setUpClass(cls):
        # One shared server process for all tests in this class.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self):
        """POST a fixed prompt to /generate and return the generated text."""
        resp = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {"temperature": 0, "max_new_tokens": 32},
            },
        )
        payload = resp.json()
        print("=" * 100)
        print(f"[Server Mode] Generated text: {payload['text']}")
        return payload["text"]

    def get_model_info(self):
        """Return the model_path currently reported by /get_model_info."""
        info = requests.get(self.base_url + "/get_model_info").json()
        print(json.dumps(info))
        return info["model_path"]

    def run_update_weights(self, model_path):
        """POST to /update_weights_from_disk and return the decoded JSON reply."""
        reply = requests.post(
            self.base_url + "/update_weights_from_disk",
            json={"model_path": model_path},
        ).json()
        print(json.dumps(reply))
        return reply

    def test_update_weights(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # Switch to the base (non-Instruct) checkpoint.
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        reply = self.run_update_weights(new_model_path)
        self.assertTrue(reply["success"])

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        # Revert to the original checkpoint; output must match the baseline again.
        reply = self.run_update_weights(origin_model_path)
        self.assertTrue(reply["success"])
        updated_model_path = self.get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])

    def test_update_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # Point the server at a checkpoint path that does not exist.
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
        reply = self.run_update_weights(new_model_path)
        self.assertFalse(reply["success"])

        # Failed update must leave both the model path and the outputs unchanged.
        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])
class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
    """Verify update_weights_from_disk(abort_all_requests=True) aborts every
    in-flight generation request before the weights are swapped."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Cap concurrency so the 32 submitted requests queue up and are
            # still in flight when the weight update arrives.
            # Fix: pass the CLI value as a string, consistent with the other
            # launch sites in this file (argv entries must be str).
            other_args=["--max-running-requests", "8"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self, max_new_tokens=32):
        """POST one long-running /generate request; return the decoded JSON.

        ignore_eos keeps the request alive long enough to be aborted.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
        return response.json()

    def get_model_info(self):
        """Return the model_path currently reported by /get_model_info."""
        response = requests.get(self.base_url + "/get_model_info")
        model_path = response.json()["model_path"]
        print(json.dumps(response.json()))
        return model_path

    def run_update_weights(self, model_path, abort_all_requests=False):
        """POST to /update_weights_from_disk; return the decoded JSON reply."""
        response = requests.post(
            self.base_url + "/update_weights_from_disk",
            json={
                "model_path": model_path,
                "abort_all_requests": abort_all_requests,
            },
        )
        ret = response.json()
        print(json.dumps(ret))
        return ret

    def test_update_weights_abort_all_requests(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")

        num_requests = 32
        with ThreadPoolExecutor(num_requests) as executor:
            futures = [
                executor.submit(self.run_decode, 16000) for _ in range(num_requests)
            ]

            # ensure the decode has been started
            time.sleep(2)

            new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
            ret = self.run_update_weights(new_model_path, abort_all_requests=True)
            self.assertTrue(ret["success"])

            # Every in-flight request must report an "abort" finish reason.
            for future in as_completed(futures):
                self.assertEqual(
                    future.result()["meta_info"]["finish_reason"]["type"], "abort"
                )

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)
###############################################################################
# Parameterized Tests for update_weights_from_disk
# Test coverage is determined based on the value of is_in_ci:
# - In a CI environment: randomly select one mode (Engine or Server) and test only with tp=1, dp=1.
# - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations
#   with tp and dp ranging from 1 to 2.
###############################################################################
class TestUpdateWeightsFromDiskParameterized(CustomTestCase):
    """Run the swap/assert/revert cycle across Engine/Server modes and
    tp/dp configurations (see module banner for the coverage policy)."""

    def run_common_test(self, mode, tp, dp):
        """
        Common test procedure for update_weights_from_disk.
        For Engine mode, we instantiate the engine with tp_size=tp.
        For Server mode, we launch the server with additional arguments for tp (dp is not used in server launch here).
        """
        if mode == "Engine":
            # Instantiate engine with additional parameter tp_size.
            print(f"[Parameterized Engine] Testing with tp={tp}, dp={dp}")
            engine = sgl.Engine(
                model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                random_seed=42,
                tp_size=tp,
                # dp parameter is not explicitly used in this API.
            )
            try:
                # The helper's return value (the baseline text) is not needed
                # here, so it is deliberately discarded.
                self._engine_update_weights_test(engine)
            finally:
                engine.shutdown()
        elif mode == "Server":
            print(f"[Parameterized Server] Testing with tp={tp}, dp={dp}")
            # Pass additional arguments to launch the server.
            base_args = ["--tp-size", str(tp)]
            process = popen_launch_server(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                DEFAULT_URL_FOR_TEST,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=base_args,
            )
            try:
                self._server_update_weights_test(DEFAULT_URL_FOR_TEST)
            finally:
                kill_process_tree(process.pid)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _engine_update_weights_test(self, engine):
        """Swap weights on *engine*, check output changes, then revert.

        Returns the baseline generated text.
        """

        def run_decode():
            prompts = ["The capital of France is"]
            sampling_params = {"temperature": 0, "max_new_tokens": 32}
            outputs = engine.generate(prompts, sampling_params)
            print("=" * 100)
            print(
                f"[Parameterized Engine] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
            )
            return outputs[0]["text"]

        def run_update_weights(model_path):
            ret = engine.update_weights_from_disk(model_path)
            print(json.dumps(ret))
            return ret

        origin_response = run_decode()
        # Swap to the base (non-Instruct) checkpoint; output should change.
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = run_update_weights(new_model_path)
        self.assertTrue(ret[0])  # ret is a tuple; index 0 holds the success flag
        updated_response = run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])
        # Revert and confirm the baseline output is reproduced.
        ret = run_update_weights(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        self.assertTrue(ret[0])
        reverted_response = run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])
        return origin_response

    def _server_update_weights_test(self, base_url):
        """Swap weights via HTTP on *base_url*, verify, then revert.

        Returns the baseline generated text.
        """

        def run_decode():
            response = requests.post(
                base_url + "/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {"temperature": 0, "max_new_tokens": 32},
                },
            )
            print("=" * 100)
            print(f"[Parameterized Server] Generated text: {response.json()['text']}")
            return response.json()["text"]

        def get_model_info():
            response = requests.get(base_url + "/get_model_info")
            model_path = response.json()["model_path"]
            print(json.dumps(response.json()))
            return model_path

        def run_update_weights(model_path):
            response = requests.post(
                base_url + "/update_weights_from_disk",
                json={"model_path": model_path},
            )
            ret = response.json()
            print(json.dumps(ret))
            return ret

        origin_model_path = get_model_info()
        origin_response = run_decode()
        # Swap to the base (non-Instruct) checkpoint.
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = run_update_weights(new_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = get_model_info()
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)
        updated_response = run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])
        # Revert and confirm path and output both return to the baseline.
        ret = run_update_weights(origin_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)
        reverted_response = run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])
        return origin_response

    def test_parameterized_update_weights(self):
        if is_in_ci():
            # In CI, choose one random mode (Engine or Server) with tp=1, dp=1.
            mode = random.choice(["Engine", "Server"])
            test_suits = [(1, 1, mode)]
        else:
            # Otherwise, test both modes and enumerate tp,dp combinations from 1 to 2.
            test_suits = []
            for mode in ["Engine", "Server"]:
                for tp in [1, 2]:
                    for dp in [1, 2]:
                        test_suits.append((tp, dp, mode))
        for tp, dp, mode in test_suits:
            with self.subTest(mode=mode, tp=tp, dp=dp):
                self.run_common_test(mode, tp, dp)


# Allow running this test module directly: `python test_update_weights_from_disk.py`.
if __name__ == "__main__":
    unittest.main()