import json
import random
import unittest

import requests

import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)


###############################################################################
# Engine Mode Tests (Single-configuration)
###############################################################################
class TestEngineUpdateWeightsFromDisk(CustomTestCase):
    def setUp(self):
        self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        # Initialize the engine in offline (direct) mode.
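        # No HTTP server is launched here; contrast with TestServerUpdateWeightsFromDisk below.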
        self.engine = sgl.Engine(model_path=self.model)

    def tearDown(self):
        self.engine.shutdown()

    def run_decode(self):
        prompts = ["The capital of France is"]
        sampling_params = {"temperature": 0, "max_new_tokens": 32}
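        # temperature=0 gives greedy, deterministic decoding, so outputs can be compared by exact prefix.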
        outputs = self.engine.generate(prompts, sampling_params)
        print("=" * 100)
        print(
            f"[Engine Mode] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
        )
        return outputs[0]["text"]

    def run_update_weights(self, model_path):
        ret = self.engine.update_weights_from_disk(model_path)
        print(json.dumps(ret))
        return ret

    def test_update_weights(self):
        origin_response = self.run_decode()
        # Update weights: use new model (remove "-Instruct")
        new_model_path = self.model.replace("-Instruct", "")
        ret = self.run_update_weights(new_model_path)
        self.assertTrue(ret[0])  # ret is a tuple; index 0 holds the success flag

        updated_response = self.run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        # Revert back to original weights
        ret = self.run_update_weights(self.model)
        self.assertTrue(ret[0])
        reverted_response = self.run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])

    def test_update_weights_unexist_model(self):
        origin_response = self.run_decode()
        new_model_path = self.model.replace("-Instruct", "wrong")
        ret = self.run_update_weights(new_model_path)
        self.assertFalse(ret[0])
        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])


###############################################################################
# HTTP Server Mode Tests (Single-configuration)
###############################################################################
class TestServerUpdateWeightsFromDisk(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {"temperature": 0, "max_new_tokens": 32},
            },
        )
        print("=" * 100)
        print(f"[Server Mode] Generated text: {response.json()['text']}")
        return response.json()["text"]

    def get_model_info(self):
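        # /get_model_info reports the model_path currently loaded by the server,
        # which the tests use to verify weight swaps.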
        response = requests.get(self.base_url + "/get_model_info")
        model_path = response.json()["model_path"]
        print(json.dumps(response.json()))
        return model_path

    def run_update_weights(self, model_path):
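        # POST the new checkpoint path to /update_weights_from_disk; the JSON reply
        # carries a "success" flag that the tests assert on.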
        response = requests.post(
            self.base_url + "/update_weights_from_disk",
            json={"model_path": model_path},
        )
        ret = response.json()
        print(json.dumps(ret))
        return ret

    def test_update_weights(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = self.run_update_weights(new_model_path)
        self.assertTrue(ret["success"])

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])

        ret = self.run_update_weights(origin_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = self.get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])

    def test_update_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
        ret = self.run_update_weights(new_model_path)
        self.assertFalse(ret["success"])

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, origin_model_path)

        updated_response = self.run_decode()
        self.assertEqual(origin_response[:32], updated_response[:32])


###############################################################################
# Parameterized Tests for update_weights_from_disk
# Test coverage is determined by the value of is_in_ci():
# - In a CI environment: randomly select one mode (Engine or Server) and test only with tp=1, dp=1.
# - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations
#   with tp and dp ranging from 1 to 2.
###############################################################################
class TestUpdateWeightsFromDiskParameterized(CustomTestCase):
    def run_common_test(self, mode, tp, dp):
        """
        Common test procedure for update_weights_from_disk.
        For Engine mode, we instantiate the engine with tp_size=tp.
        For Server mode, we launch the server with a --tp-size argument (dp is not forwarded in either mode).
        """
        if mode == "Engine":
            # Instantiate engine with additional parameter tp_size.
            print(f"[Parameterized Engine] Testing with tp={tp}, dp={dp}")
            engine = sgl.Engine(
                model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                random_seed=42,
                tp_size=tp,
                # dp parameter is not explicitly used in this API.
            )
            try:
                self._engine_update_weights_test(engine)
            finally:
                engine.shutdown()
        elif mode == "Server":
            print(f"[Parameterized Server] Testing with tp={tp}, dp={dp}")
            # Pass additional arguments to launch the server.
            base_args = ["--tp-size", str(tp)]
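            # NOTE: dp is enumerated by the caller but intentionally not forwarded here (see the docstring above).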
            process = popen_launch_server(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                DEFAULT_URL_FOR_TEST,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=base_args,
            )
            try:
                self._server_update_weights_test(DEFAULT_URL_FOR_TEST)
            finally:
                kill_process_tree(process.pid)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _engine_update_weights_test(self, engine):
        # Run the update weights test on the given engine instance.
        def run_decode():
            prompts = ["The capital of France is"]
            sampling_params = {"temperature": 0, "max_new_tokens": 32}
            outputs = engine.generate(prompts, sampling_params)
            print("=" * 100)
            print(
                f"[Parameterized Engine] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
            )
            return outputs[0]["text"]

        def run_update_weights(model_path):
            ret = engine.update_weights_from_disk(model_path)
            print(json.dumps(ret))
            return ret

        origin_response = run_decode()
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = run_update_weights(new_model_path)
        self.assertTrue(ret[0])
        updated_response = run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])
        ret = run_update_weights(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        self.assertTrue(ret[0])
        reverted_response = run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])
        return origin_response

    def _server_update_weights_test(self, base_url):
        def run_decode():
            response = requests.post(
                base_url + "/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {"temperature": 0, "max_new_tokens": 32},
                },
            )
            print("=" * 100)
            print(f"[Parameterized Server] Generated text: {response.json()['text']}")
            return response.json()["text"]

        def get_model_info():
            response = requests.get(base_url + "/get_model_info")
            model_path = response.json()["model_path"]
            print(json.dumps(response.json()))
            return model_path

        def run_update_weights(model_path):
            response = requests.post(
                base_url + "/update_weights_from_disk",
                json={"model_path": model_path},
            )
            ret = response.json()
            print(json.dumps(ret))
            return ret

        origin_model_path = get_model_info()
        origin_response = run_decode()
        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
        ret = run_update_weights(new_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = get_model_info()
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)
        updated_response = run_decode()
        self.assertNotEqual(origin_response[:32], updated_response[:32])
        ret = run_update_weights(origin_model_path)
        self.assertTrue(ret["success"])
        updated_model_path = get_model_info()
        self.assertEqual(updated_model_path, origin_model_path)
        reverted_response = run_decode()
        self.assertEqual(origin_response[:32], reverted_response[:32])
        return origin_response

    def test_parameterized_update_weights(self):
        if is_in_ci():
            # In CI, choose one random mode (Engine or Server) with tp=1, dp=1.
            mode = random.choice(["Engine", "Server"])
            test_suits = [(1, 1, mode)]
        else:
            # Otherwise, test both modes and enumerate tp,dp combinations from 1 to 2.
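            # NOTE: the tp=2 combinations assume a machine with at least two GPUs.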
            test_suits = []
            for mode in ["Engine", "Server"]:
                for tp in [1, 2]:
                    for dp in [1, 2]:
                        test_suits.append((tp, dp, mode))
        for tp, dp, mode in test_suits:
            with self.subTest(mode=mode, tp=tp, dp=dp):
                self.run_common_test(mode, tp, dp)


if __name__ == "__main__":
    unittest.main()