"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""The arguments of the server."""

import argparse
import dataclasses
import logging
import random
import tempfile
from typing import List, Optional

from sglang.srt.utils import (
    get_amdgpu_memory_capacity,
    get_nvgpu_memory_capacity,
    is_flashinfer_available,
    is_hip,
    is_ipv6,
    is_port_available,
)

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ServerArgs:
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: int = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    download_dir: Optional[str] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_penalizer: bool = False
    disable_nan_detection: bool = False
    enable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: int = 160
    torchao_config: str = ""
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False

    def __post_init__(self):
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Adjust for GPUs with small memory capacities
        if is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        else:
            gpu_mem = get_nvgpu_memory_capacity()

        if gpu_mem < 25000:
            logger.warning(
                "Automatically adjust --chunked-prefill-size for small GPUs."
            )
            if self.chunked_prefill_size is not None:
                # Guard against chunked prefill being disabled (set to None above).
                self.chunked_prefill_size //= 4  # make it 2048
            self.cuda_graph_max_bs = 4

        # Deprecation warnings
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
            self.attention_backend = "triton"
        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead."
            )
            self.sampling_backend = "pytorch"

        if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            if self.chunked_prefill_size is not None:
                # Guard against chunked prefill being disabled (set to None above).
                self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.disable_cuda_graph = True
            self.enable_overlap_schedule = False
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issues. "
                "The CUDA graph is disabled."
            )

        if self.enable_overlap_schedule:
            logger.warning(
                "Overlap scheduler mode is enabled. This is an experimental feature. "
                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                "and embedding APIs are not supported and will lead to wrong results. "
                "The NaN detection is also disabled."
            )
            self.disable_penalizer = True
            self.disable_nan_detection = True

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "For unknown reasons, the tokenizer of this model appends an extra token to the prompt when trust_remote_code=True, so trust_remote_code is disabled."
            )
            self.trust_remote_code = False

        if "gemma-2" in self.model_path.lower():
            logger.info(
                "gemma-2 uses sliding window attention, so the attention backend is set to flashinfer."
            )
            self.attention_backend = "flashinfer"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Model and port args
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip initializing the tokenizer and pass input_ids directly in generate requests.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            choices=["cuda", "xpu"],
            help="The device type.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The built-in chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )

        # Memory and scheduling
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )

        # Other runtime options
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model to generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory.",
        )

        # Logging
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of the HTTP server. If not set, reuses --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable Prometheus metrics logging.",
        )
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch.",
        )

        # API related
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set the API key of the server. It is also used in the OpenAI API-compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in the backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return the number of cached tokens in usage.prompt_tokens_details for each OpenAI request.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
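        # For illustration, a hypothetical invocation that overrides one model
        # config field from the CLI (any valid Hugging Face config key works):
        #   --json-model-override-args '{"max_position_embeddings": 8192}'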

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The threshold for applying sparse decoding in double sparsity attention",
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either plain paths or renamed paths in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=ServerArgs.max_loras_per_batch,
            help="Maximum number of adapters for a running batch, including base-only requests.",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for grammar-guided decoding.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-jump-forward",
            action="store_true",
            help="Disable jump-forward for grammar-guided decoding.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--disable-penalizer",
            action="store_true",
            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
        )
        parser.add_argument(
            "--disable-nan-detection",
            action="store_true",
            help="Disable the NaN detection for better performance.",
        )
        parser.add_argument(
            "--enable-overlap-schedule",
            action="store_true",
            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enable data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access; otherwise, P2P access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
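        # The remaining argparse destinations match the dataclass field names
        # one-to-one (argparse converts "--foo-bar" to dest "foo_bar"), so the
        # parsed namespace can be unpacked directly below.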
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1
        ), "multi-node data parallel is not supported"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of LoRA with CUDA graph and radix cache is still in progress"

        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path


def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args
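
# A minimal usage sketch (the model path below is only an illustrative example):
#   server_args = prepare_server_args(
#       ["--model-path", "meta-llama/Llama-3.1-8B-Instruct", "--port", "30000"]
#   )
#   print(server_args.url())  # -> http://127.0.0.1:30000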


@dataclasses.dataclass
class PortArgs:
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def init_new(server_args) -> "PortArgs":
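        # Derive the nccl port from the server port with a fixed offset, then
        # probe upward in the same stride until a free port is found.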
        port = server_args.port + 42
        while True:
            if is_port_available(port):
                break
            port += 42

        return PortArgs(
            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            nccl_port=port,
        )


class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path
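

# For illustration: given `--lora-paths lora-a lora-b=/ckpt/lora-b` (hypothetical
# adapter names), the action above produces
# {"lora-a": "lora-a", "lora-b": "/ckpt/lora-b"}.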