"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""The arguments of the server."""

import argparse
import dataclasses
import logging
import random
import tempfile
from typing import List, Optional

from sglang.srt.utils import is_hip, is_ipv6, is_port_available

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ServerArgs:
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: int = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Distributed args
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Optimization/debug options
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None

    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    max_torch_compile_bs: int = 32
    torchao_config: str = ""
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    def __post_init__(self):
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Deprecation warnings
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
            self.attention_backend = "triton"

        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead."
            )
            self.sampling_backend = "pytorch"

        # ROCm: flashinfer available later
        if is_hip():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
            )
            self.trust_remote_code = False

        if "gemma-2" in self.model_path.lower():
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.attention_backend = "flashinfer"
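
    # Example (a hedged sketch; the model path below is hypothetical):
    # constructing ServerArgs directly runs the default resolution above.
    #
    #     args = ServerArgs(model_path="my-org/my-model", tp_size=2)
    #     assert args.tokenizer_path == "my-org/my-model"
    #     assert args.mem_fraction_static == 0.87  # the tp_size >= 2 branch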

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip initializing the tokenizer and pass input_ids directly in generate requests.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of the HTTP server. If not set, reuses --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show the time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set the API key of the server. It is also used in the OpenAI API-compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing the distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
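
        # Example (a hedged sketch; the override key is model-dependent):
        #     --json-model-override-args '{"rope_scaling": {"type": "linear", "factor": 2.0}}'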

        # Optimization/debug options
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--max-torch-compile-bs",
            type=int,
            default=ServerArgs.max_torch_compile_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access; otherwise, P2P access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

        # LoRA options
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. Provide either plain paths or renamed paths in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes > 1
        ), "multi-node data parallel is not supported"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of LoRA with CUDA graph and radix attention is a work in progress"

        assert self.dp_size == 1, (
            "The support for data parallelism is temporarily disabled during refactor. "
            "Please use sglang<=0.3.2 or wait for later updates."
        )

        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path


def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args
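
# Example (a minimal, hedged sketch; the model path is hypothetical):
#
#     server_args = prepare_server_args(
#         ["--model-path", "my-org/my-model", "--tp-size", "2"]
#     )
#     print(server_args.url())  # http://127.0.0.1:30000 by default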


@dataclasses.dataclass
class PortArgs:
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization for multiple TP groups (torch.dist)
    nccl_ports: List[int]

    @classmethod
    def init_new(cls, server_args):
        port = server_args.port + 1
        while True:
            if is_port_available(port):
                break
            port += 1

        return PortArgs(
            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            nccl_ports=[port],
        )
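
    # Example (a hedged sketch): init_new scans upward from server_args.port + 1
    # for a free port for NCCL initialization, while the zmq channels use the
    # temporary ipc files created above.
    #
    #     port_args = PortArgs.init_new(server_args)
    #     assert port_args.nccl_ports[0] > server_args.port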


class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path
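
# Example (a hedged sketch; adapter paths are hypothetical): LoRAPathAction
# turns both plain and renamed paths into a {name: path} dict.
#
#     parser = argparse.ArgumentParser()
#     ServerArgs.add_cli_args(parser)
#     args = parser.parse_args(
#         ["--model-path", "my-org/my-model",
#          "--lora-paths", "sql=/adapters/sql", "/adapters/chat"]
#     )
#     assert args.lora_paths == {
#         "sql": "/adapters/sql",
#         "/adapters/chat": "/adapters/chat",
#     }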