server_args.py 16.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import random
21
from typing import List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
22
23
24
25


@dataclasses.dataclass
class ServerArgs:
    """Arguments that configure the SGLang server.

    Fields left as ``None`` are resolved in ``__post_init__`` — e.g. the
    tokenizer path falls back to the model path and the static memory
    fraction is derived from the tensor-parallel size.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    load_format: str = "auto"
    dtype: str = "auto"
    # NOTE(review): the dataclass default is True, but the CLI flag below is
    # ``store_true`` (argparse default False) — instances built through
    # from_cli_args default to False unless --trust-remote-code is passed.
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    # A bare int is normalized to a one-element list in __post_init__.
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_prefill_tokens: Optional[int] = None
    max_running_requests: Optional[int] = None
    max_num_reqs: Optional[int] = None
    max_total_tokens: Optional[int] = None
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGlang_storage"

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Chunked Prefill
    chunked_prefill_size: Optional[int] = None

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_disk_cache: bool = False
    enable_torch_compile: bool = False
    enable_p2p_check: bool = False
    enable_mla: bool = False
    attention_reduce_in_fp32: bool = False
    efficient_weight_load: bool = False

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    def __post_init__(self):
        """Resolve the defaults that depend on other fields."""
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        # Larger tensor-parallel sizes get a smaller static fraction,
        # leaving more headroom per GPU.
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Normalize additional_ports so downstream code always sees a list.
        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every server argument on ``parser``.

        Flag defaults mirror the dataclass defaults (read off the class
        attributes) so that ``from_cli_args`` reconstructs an equivalent
        ``ServerArgs``.
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            # Fixed typo: "buliltin" -> "builtin".
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-num-reqs",
            type=int,
            default=ServerArgs.max_num_reqs,
            # Fixed grammar: "If the model have" -> "If the model has".
            help="The maximum number of requests to serve in the memory pool. If the model has a large context length, you may need to decrease this value to avoid out-of-memory errors.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument("--node-rank", type=int, help="The node rank.")

        # Chunked prefill
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The size of the chunked prefill.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile, experimental feature.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--enable-mla",
            action="store_true",
            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
            # Fixed typo ("intermidiate") and the missing space between sentences.
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ``ServerArgs`` from a parsed argparse namespace.

        The CLI registers the long option names; copy them onto the short
        dataclass field names before extracting the fields.
        """
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the base HTTP URL of the server."""
        return f"http://{self.host}:{self.port}"

    def print_mode_args(self):
        """Return a one-line summary of the optimization/debug toggles."""
        return (
            f"disable_flashinfer={self.disable_flashinfer}, "
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
        )

    def check_server_args(self):
        """Validate cross-field constraints.

        Raises:
            AssertionError: if tp_size is not divisible by nnodes, or if
                data parallelism is combined with multi-node serving.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"

Lianmin Zheng's avatar
Lianmin Zheng committed
443
444
445
446

@dataclasses.dataclass
class PortArgs:
    """Port assignments for the server's subprocesses.

    NOTE(review): field meanings are inferred from their names — confirm
    against the code that constructs and consumes PortArgs.
    """

    # Port used by the tokenizer process.
    tokenizer_port: int
    # Port used by the controller process.
    controller_port: int
    # Port used by the detokenizer process.
    detokenizer_port: int
    # Ports reserved for NCCL initialization of distributed workers.
    nccl_ports: List[int]