server_args.py 28.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import logging
21
import random
22
23
import tempfile
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
26

27
28
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
29
30
31

@dataclasses.dataclass
class ServerArgs:
    """Configuration options of the server.

    The fields mirror the command line flags registered by ``add_cli_args``.
    Fields whose default is ``None`` are either optional or filled in by
    ``__post_init__`` once the instance has been created.
    """

    # Model and tokenizer
    model_path: str
    # Falls back to `model_path` in `__post_init__` when left unset.
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    # NOTE: `add_cli_args` registers `--trust-remote-code` as a store_true
    # flag, so instances built by `from_cli_args` get False unless the flag is
    # passed, even though the dataclass default is True.
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    # Falls back to `model_path` in `__post_init__` when left unset.
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    # Chosen from `tp_size` in `__post_init__` when left unset.
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # A value <= 0 is replaced by None in `__post_init__`, which means chunked
    # prefill is disabled; hence the Optional annotation.
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    # Drawn at random in `__post_init__` when left unset.
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    decode_log_interval: int = 40

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"
    enable_cache_report: bool = False
    watchdog_timeout: float = 600

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Distributed args
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    # Given as a list of "path" or "name=path" entries; `check_server_args`
    # rewrites a list into a {name: path} dict.
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    # Both backends default to "flashinfer" in `__post_init__` when left unset.
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_penalizer: bool = False
    disable_nan_detection: bool = False
    enable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: int = 160
    torchao_config: str = ""
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    num_continuous_decode_steps: int = 1
Lianmin Zheng's avatar
Lianmin Zheng committed
131
132

    def __post_init__(self):
        """Fill in derived defaults and reconcile dependent options.

        Called by the dataclass-generated ``__init__``. The steps run in a
        fixed order because later steps read values set by earlier ones:
        missing defaults, deprecated flags, the flashinfer availability
        fallback, default kernel backends, overlap-scheduler restrictions and
        finally model-specific patches.
        """
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        # None already means "disabled"; guarding it keeps this method safe to
        # run again on processed values (e.g. through dataclasses.replace).
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Deprecation warnings
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
            self.attention_backend = "triton"
        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead. "
            )
            self.sampling_backend = "pytorch"

        # Without flashinfer, both backends fall back regardless of what was
        # requested.
        if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        if self.enable_overlap_schedule:
            logger.warning(
                "Overlap scheduler mode is enabled. This is an experimental feature. "
                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                "and embedding APIs are not supported and will lead to wrong results. "
                "The NaN detection is also disabled."
            )
            self.disable_penalizer = True
            self.disable_nan_detection = True

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_code=True"
            )
            self.trust_remote_code = False

        if "gemma-2" in self.model_path.lower():
            # NOTE(review): this overrides the backend chosen above, including
            # the triton fallback used when flashinfer is unavailable. Confirm
            # that this is intended.
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.attention_backend = "flashinfer"

Lianmin Zheng's avatar
Lianmin Zheng committed
206
207
208
209
210
211
212
213
214
215
216
217
218
219
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every server option on ``parser``.

        Value options take their defaults from the ``ServerArgs`` class
        attributes so the CLI and the dataclass cannot drift apart. Boolean
        options are ``store_true`` flags and therefore default to False on the
        command line.

        Args:
            parser: The argument parser to extend in place.
        """
        # Model and tokenizer
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            choices=["cuda", "xpu"],
            help="The device type.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )

        # Memory and scheduling
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )

        # Other runtime options
        # argparse derives the destination from the first long option, so the
        # value lands in `tensor_parallel_size`; `from_cli_args` copies it to
        # `tp_size`.
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )

        # Logging
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )

        # Other
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch",
        )

        # Data parallelism
        # Stored as `data_parallel_size`; `from_cli_args` copies it to `dp_size`.
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            default=ServerArgs.dist_init_addr,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The threshold for sparse decoding in double sparsity attention",
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=ServerArgs.lora_paths,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=ServerArgs.max_loras_per_batch,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for constrained decoding.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--disable-penalizer",
            action="store_true",
            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
        )
        parser.add_argument(
            "--disable-nan-detection",
            action="store_true",
            help="Disable the NaN detection for better performance.",
        )
        parser.add_argument(
            "--enable-overlap-schedule",
            action="store_true",
            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
680

Lianmin Zheng's avatar
Lianmin Zheng committed
681
682
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
683
684
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
Lianmin Zheng's avatar
Lianmin Zheng committed
685
686
687
688
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the HTTP base URL of the server.

        An IPv6 host (as reported by ``is_ipv6``) is wrapped in square
        brackets so the port separator stays unambiguous.
        """
        host = f"[{self.host}]" if is_ipv6(self.host) else self.host
        return f"http://{host}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
693

694
695
696
697
698
    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
699
            self.dp_size > 1 and self.nnodes != 1
700
        ), "multi-node data parallel is not supported"
701
702
703
704
705
706
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"
707

708
709
710
711
712
713
714
715
716
717
        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path

Lianmin Zheng's avatar
Lianmin Zheng committed
718

Lianmin Zheng's avatar
Lianmin Zheng committed
719
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Parse command line arguments into a `ServerArgs` instance.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    cli_parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli_parser)
    return ServerArgs.from_cli_args(cli_parser.parse_args(argv))


Lianmin Zheng's avatar
Lianmin Zheng committed
737
738
@dataclasses.dataclass
class PortArgs:
    """IPC file names and the NCCL port used by the server processes."""

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def _new_ipc_name() -> str:
        """Create an empty temporary file and return its path.

        The file is kept on disk (``delete=False``), but the handle is closed
        explicitly instead of being left for the garbage collector.
        """
        with tempfile.NamedTemporaryFile(delete=False) as ipc_file:
            return ipc_file.name

    @staticmethod
    def init_new(server_args) -> "PortArgs":
        """Allocate a fresh set of IPC file names and an NCCL port.

        The port search starts at ``server_args.port + 42`` and advances in
        steps of 42 until ``is_port_available`` accepts a candidate.

        Args:
            server_args: Object exposing an integer ``port`` attribute.

        Returns:
            A new ``PortArgs`` with three distinct temporary file names and
            the selected port.
        """
        # NOTE(review): the search has no upper bound on the port number;
        # confirm how is_port_available behaves for values above 65535.
        port = server_args.port + 42
        while not is_port_available(port):
            port += 42

        return PortArgs(
            tokenizer_ipc_name=PortArgs._new_ipc_name(),
            scheduler_input_ipc_name=PortArgs._new_ipc_name(),
            detokenizer_ipc_name=PortArgs._new_ipc_name(),
            nccl_port=port,
        )

764
765
766
767
768
769
770
771
772
773

class LoRAPathAction(argparse.Action):
    """argparse action that collects LoRA path tokens into a name-to-path dict.

    Each token is either ``NAME=PATH`` or a bare ``PATH``; a bare path is
    stored under itself as the name. Only the first ``=`` separates the name
    from the path.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Publish the (initially empty) mapping first, then fill it in place.
        parsed = {}
        setattr(namespace, self.dest, parsed)
        for entry in values:
            name, sep, path = entry.partition("=")
            parsed[name] = path if sep else entry