import json
import unittest
from unittest.mock import MagicMock, patch

from sglang.srt.server_args import PortArgs, prepare_server_args
from sglang.test.test_utils import CustomTestCase


class TestPrepareServerArgs(CustomTestCase):
    """Tests for ``prepare_server_args`` given an explicit argv-style list."""

    def test_prepare_server_args(self):
        """The model path and the JSON override string both reach the result.

        The override is compared after ``json.loads`` so the assertion depends
        on the decoded content rather than on the exact string formatting.
        """
        server_args = prepare_server_args(
            [
                "--model-path",
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
                "--json-model-override-args",
                '{"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}',
            ]
        )
        self.assertEqual(
            server_args.model_path, "meta-llama/Meta-Llama-3.1-8B-Instruct"
        )
        self.assertEqual(
            json.loads(server_args.json_model_override_args),
            {"rope_scaling": {"factor": 2.0, "rope_type": "linear"}},
        )


class TestPortArgs(unittest.TestCase):
    """Tests for ``PortArgs.init_new`` driven by a mocked ``server_args``.

    Every test patches ``sglang.srt.server_args.is_port_available`` to return
    ``True``. Tests that use bracketed IPv6 addresses additionally patch
    ``is_valid_ipv6_address`` to force the validation outcome they need.
    """

    @staticmethod
    def _make_server_args(**overrides):
        """Return a ``MagicMock`` server-args stub shared by all tests.

        ``port`` is 30000 and ``nccl_port`` is ``None``; each keyword in
        ``overrides`` is set as an attribute on the mock. Attributes that are
        not set explicitly stay auto-created ``MagicMock`` values.
        """
        server_args = MagicMock()
        server_args.port = 30000
        server_args.nccl_port = None
        for attr_name, value in overrides.items():
            setattr(server_args, attr_name, value)
        return server_args

    def _assert_endpoints_start_with(self, port_args, prefix):
        """Assert the tokenizer, scheduler-input and detokenizer names share ``prefix``."""
        for field in (
            "tokenizer_ipc_name",
            "scheduler_input_ipc_name",
            "detokenizer_ipc_name",
        ):
            endpoint = getattr(port_args, field)
            self.assertTrue(
                endpoint.startswith(prefix),
                msg=f"{field}={endpoint!r} does not start with {prefix!r}",
            )

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
    def test_init_new_standard_case(self, mock_temp_file, mock_is_port_available):
        """With DP attention disabled, all three endpoints use ``ipc://``."""
        mock_is_port_available.return_value = True
        mock_temp_file.return_value.name = "temp_file"

        server_args = self._make_server_args(enable_dp_attention=False)

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "ipc://")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_single_node_dp_attention(self, mock_is_port_available):
        """One node, DP attention on, no dist-init address: endpoints are local TCP."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=1, dist_init_addr=None
        )

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://127.0.0.1:")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_dp_rank(self, mock_is_port_available):
        """The scheduler input endpoint ends with ``worker_ports[dp_rank]``."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=1, dist_init_addr="192.168.1.1:25000"
        )

        worker_ports = [25006, 25007, 25008, 25009]
        port_args = PortArgs.init_new(server_args, dp_rank=2, worker_ports=worker_ports)

        # Index 2 of worker_ports is 25008.
        self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25008"))

        self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
        self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_ipv4_address(self, mock_is_port_available):
        """Two nodes with an IPv4 ``host:port``: endpoints use that host over TCP."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="192.168.1.1:25000"
        )

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://192.168.1.1:")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_malformed_ipv4_address(self, mock_is_port_available):
        """An IPv4 address without a port is rejected with ``AssertionError``."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="192.168.1.1"
        )

        with self.assertRaises(AssertionError) as context:
            PortArgs.init_new(server_args)

        self.assertIn(
            "please provide --dist-init-addr as host:port", str(context.exception)
        )

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_malformed_ipv4_address_invalid_port(
        self, mock_is_port_available
    ):
        """A non-numeric IPv4 port is rejected with ``ValueError``."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="192.168.1.1:abc"
        )

        # Only the exception type is checked here; the message is not asserted.
        with self.assertRaises(ValueError):
            PortArgs.init_new(server_args)

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_ipv6_address(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A bracketed IPv6 ``[host]:port`` keeps its brackets in the endpoints."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]:25000"
        )

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://[2001:db8::1]:")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=False)
    def test_init_new_with_invalid_ipv6_address(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A bracketed host that fails IPv6 validation raises ``ValueError``."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[invalid-ipv6]:25000"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid IPv6 address", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available")
    def test_init_new_with_malformed_ipv6_address_missing_bracket(
        self, mock_is_port_available
    ):
        """An IPv6 address with no closing ``]`` raises ``ValueError``."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1:25000"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid IPv6 address format", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_missing_port(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A bracketed IPv6 address with no port raises ``ValueError``."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn(
            "a port must be specified in IPv6 address", str(context.exception)
        )

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_invalid_port(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A non-numeric port after a bracketed IPv6 host raises ``ValueError``."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]:abcde"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid port in IPv6 address", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available")
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_wrong_separator(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A separator other than ``:`` after ``]`` raises ``ValueError``."""
        mock_is_port_available.return_value = True

        server_args = self._make_server_args(
            enable_dp_attention=True, nnodes=2, dist_init_addr="[2001:db8::1]#25000"
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("expected ':' after ']'", str(context.exception))


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()