"docs/vscode:/vscode.git/clone" did not exist on "74f00474e1969f4fa5822c36e6328724ff109b00"
server_args.py 30 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import logging
21
import random
22
23
import tempfile
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.utils import (
HAI's avatar
HAI committed
26
27
    get_amdgpu_memory_capacity,
    get_nvgpu_memory_capacity,
28
    is_flashinfer_available,
HAI's avatar
HAI committed
29
    is_hip,
30
31
32
    is_ipv6,
    is_port_available,
)
33

34
35
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
36
37
38

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
39
    # Model and tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
40
41
42
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
43
    skip_tokenizer_init: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
44
    load_format: str = "auto"
45
    trust_remote_code: bool = True
Lianmin Zheng's avatar
Lianmin Zheng committed
46
    dtype: str = "auto"
47
    kv_cache_dtype: str = "auto"
Lianmin Zheng's avatar
Lianmin Zheng committed
48
    quantization: Optional[str] = None
49
50
    context_length: Optional[int] = None
    device: str = "cuda"
51
    served_model_name: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
52
    chat_template: Optional[str] = None
53
    is_embedding: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
54
55
56
57
58
59

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
60
    mem_fraction_static: Optional[float] = None
61
    max_running_requests: Optional[int] = None
62
    max_total_tokens: Optional[int] = None
63
    chunked_prefill_size: int = 8192
64
    max_prefill_tokens: int = 16384
65
    schedule_policy: str = "lpm"
66
    schedule_conservativeness: float = 1.0
Lianmin Zheng's avatar
Lianmin Zheng committed
67
68
69

    # Other runtime options
    tp_size: int = 1
70
    stream_interval: int = 1
71
    random_seed: Optional[int] = None
72
    constrained_json_whitespace_pattern: Optional[str] = None
73
    watchdog_timeout: float = 300
74
    download_dir: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
75
76
77

    # Logging
    log_level: str = "info"
78
    log_level_http: Optional[str] = None
79
    log_requests: bool = False
Liangsheng Yin's avatar
Liangsheng Yin committed
80
    show_time_cost: bool = False
81
    enable_metrics: bool = False
82
    decode_log_interval: int = 40
Liangsheng Yin's avatar
Liangsheng Yin committed
83

84
    # API related
85
    api_key: Optional[str] = None
86
    file_storage_pth: str = "SGLang_storage"
87
    enable_cache_report: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
88

89
90
91
92
    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

93
    # Multi-node distributed serving
94
    dist_init_addr: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
95
    nnodes: int = 1
96
    node_rank: int = 0
Lianmin Zheng's avatar
Lianmin Zheng committed
97
98
99
100

    # Model override args in JSON
    json_model_override_args: str = "{}"

Shuo Yang's avatar
Shuo Yang committed
101
102
103
104
105
106
107
108
    # Double Sparsity
    # Opt-in switch, exposed on the CLI as --enable-double-sparsity.
    enable_double_sparsity: bool = False
    # Path of the double sparsity channel config; None when not provided.
    ds_channel_config_path: Optional[str] = None
    # Number of heavy channels in double sparsity attention.
    ds_heavy_channel_num: int = 32
    # Number of heavy tokens in double sparsity attention.
    ds_heavy_token_num: int = 256
    # Type of heavy channels in double sparsity attention.
    ds_heavy_channel_type: str = "qk"
    # NOTE(review): the CLI help text for this option duplicates the
    # heavy-channel-type description; presumably a length threshold above
    # which sparse decode is used -- confirm against the attention backend.
    ds_sparse_decode_threshold: int = 4096

109
110
111
112
113
    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
114
115
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
116
    grammar_backend: Optional[str] = "outlines"
117

118
    # Optimization/debug options
Lianmin Zheng's avatar
Lianmin Zheng committed
119
    disable_radix_cache: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
120
    disable_jump_forward: bool = False
121
    disable_cuda_graph: bool = False
122
    disable_cuda_graph_padding: bool = False
123
    disable_disk_cache: bool = False
124
    disable_custom_all_reduce: bool = False
Ke Bao's avatar
Ke Bao committed
125
    disable_mla: bool = False
126
    disable_penalizer: bool = False
127
    disable_nan_detection: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
128
    enable_overlap_schedule: bool = False
129
    enable_mixed_chunk: bool = False
Ke Bao's avatar
Ke Bao committed
130
    enable_dp_attention: bool = False
131
    enable_torch_compile: bool = False
132
133
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: int = 160
134
    torchao_config: str = ""
135
    enable_p2p_check: bool = False
136
    triton_attention_reduce_in_fp32: bool = False
137
    num_continuous_decode_steps: int = 1
138
    delete_ckpt_after_loading: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
139
140

    def __post_init__(self):
        """Fill in derived defaults and reconcile interdependent options.

        Invoked by the dataclass machinery right after ``__init__``. It
        mutates ``self`` in place and returns nothing. In order, it:

        1. Defaults ``tokenizer_path`` and ``served_model_name`` to
           ``model_path``, normalizes a non-positive ``chunked_prefill_size``
           to ``None`` (chunked prefill disabled), and draws a random seed
           when none was given.
        2. Picks ``mem_fraction_static`` from ``tp_size`` when unset.
        3. Shrinks the chunked prefill size and the CUDA graph batch size on
           GPUs reporting less than 25000 units of memory.
        4. Resolves the attention/sampling kernel backends.
        5. Applies the DP-attention and overlap-scheduler constraints.
        6. Applies model-specific patches keyed on ``model_path``.

        ``chunked_prefill_size`` may legitimately be ``None`` here (either
        because this method already ran on a copied instance, or because it
        was just disabled above), so every arithmetic adjustment on it is
        guarded.
        """
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        # A value of None means chunked prefill is already disabled (e.g. the
        # instance was rebuilt from another ServerArgs); comparing None with 0
        # would raise TypeError, so check for it first.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.82
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Adjust for GPUs with small memory capacities
        if is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        else:
            gpu_mem = get_nvgpu_memory_capacity()
        if gpu_mem < 25000:
            logger.warning(
                "Automatically adjust --chunked-prefill-size for small GPUs."
            )
            # Only scale when chunked prefill is enabled; with the default of
            # 8192 this yields 2048.
            if self.chunked_prefill_size is not None:
                self.chunked_prefill_size //= 4
            self.cuda_graph_max_bs = 4

        # Without flashinfer, fall back to the Triton attention kernels and
        # PyTorch sampling regardless of what was requested.
        if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            # Halve the chunk size only when chunked prefill is enabled.
            if self.chunked_prefill_size is not None:
                self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
            # DP attention turns the overlap scheduler off.
            self.enable_overlap_schedule = False
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                "Data parallel size is adjusted to be the same as tensor parallel size."
            )

        if self.enable_overlap_schedule:
            logger.warning(
                "Overlap scheduler mode is enabled. This is an experimental feature. "
                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                "and embedding APIs are not supported and will lead to wrong results. "
                "The NaN detection is also disabled."
            )
            self.disable_penalizer = True
            self.disable_nan_detection = True

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
            )
            self.trust_remote_code = False

        # NOTE(review): this overrides the Triton fallback chosen above when
        # flashinfer is unavailable -- confirm that is intended.
        if "gemma-2" in self.model_path.lower():
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.attention_backend = "flashinfer"


Lianmin Zheng's avatar
Lianmin Zheng committed
223
224
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
225
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
226
227
228
229
230
231
232
233
234
235
236
237
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
238
239
240
241
242
243
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
244
245
246
247
248
249
250
251
252
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
253
254
255
256
257
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
274
275
276
277
278
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
279
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
280
            "--dtype",
Cody Yu's avatar
Cody Yu committed
281
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
282
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
283
284
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
285
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
286
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
287
288
289
290
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
291
292
            '* "float32" for FP32 precision.',
        )
293
294
295
296
297
298
299
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
300
301
302
303
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
304
305
306
307
308
309
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
310
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
311
312
                "bitsandbytes",
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
313
314
            help="The quantization method.",
        )
315
316
317
318
319
320
321
322
323
324
325
326
327
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
            choices=["cuda", "xpu"],
            help="The device type.",
        )
328
329
330
331
332
333
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
334
335
336
337
338
339
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
340
341
342
343
344
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
345
346

        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
347
348
349
350
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
351
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
352
        )
353
354
355
356
357
358
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
359
360
361
362
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
363
364
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
365
        )
366
367
368
369
370
371
372
373
374
375
376
377
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
378
        parser.add_argument(
379
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
380
            type=str,
381
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
382
            choices=["lpm", "random", "fcfs", "dfs-weight"],
383
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
384
        )
385
386
387
388
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
389
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
390
        )
391
392

        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
393
        parser.add_argument(
394
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
395
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
396
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
397
            default=ServerArgs.tp_size,
398
            help="The tensor parallelism size.",
399
        )
400
401
402
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
403
            default=ServerArgs.stream_interval,
404
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
405
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
406
407
408
409
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
410
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
411
        )
412
413
414
415
416
417
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
418
419
420
421
422
423
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
424
425
426
427
428
429
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory.",
        )
430
431

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
432
433
434
435
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
436
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
437
        )
438
        parser.add_argument(
439
440
441
442
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
443
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
444
        parser.add_argument(
445
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
446
            action="store_true",
447
            help="Log the inputs and outputs of all requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
448
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
449
450
451
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
452
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
453
        )
454
455
456
457
458
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
459
460
461
462
463
464
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch",
        )
465

466
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
467
468
469
470
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
471
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
472
        )
473
474
475
476
477
478
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
479
480
481
482
483
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
484

485
486
        # Data parallelism
        parser.add_argument(
487
            "--data-parallel-size",
488
489
490
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
491
            help="The data parallelism size.",
492
493
494
495
496
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
497
            help="The load balancing strategy for data parallelism.",
498
499
500
501
502
503
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

504
        # Multi-node distributed serving
505
        parser.add_argument(
506
507
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
508
            type=str,
509
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
510
511
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
512
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
513
        )
514
515
516
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
517

Lianmin Zheng's avatar
Lianmin Zheng committed
518
519
520
521
522
523
524
525
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

Shuo Yang's avatar
Shuo Yang committed
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

        # Kernel backend
580
581
582
583
584
585
586
587
588
589
590
591
592
593
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
594
595
596
597
598
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
599
            help="Choose the backend for grammar-guided decoding.",
600
        )
601
602

        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
603
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
604
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
605
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
606
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
607
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
608
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
609
            "--disable-jump-forward",
Liangsheng Yin's avatar
Liangsheng Yin committed
610
            action="store_true",
Lianmin Zheng's avatar
Lianmin Zheng committed
611
            help="Disable jump-forward for grammar-guided decoding.",
Liangsheng Yin's avatar
Liangsheng Yin committed
612
        )
613
614
615
616
617
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
618
        parser.add_argument(
619
620
621
622
623
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
624
625
626
627
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
628
629
630
631
632
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
Ke Bao's avatar
Ke Bao committed
633
634
635
636
637
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
638
639
640
        parser.add_argument(
            "--disable-penalizer",
            action="store_true",
641
642
643
644
645
646
            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
        )
        parser.add_argument(
            "--disable-nan-detection",
            action="store_true",
            help="Disable the NaN detection for better performance.",
647
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
648
649
650
651
652
        parser.add_argument(
            "--enable-overlap-schedule",
            action="store_true",
            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
        )
653
654
655
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
656
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
657
        )
Ke Bao's avatar
Ke Bao committed
658
659
660
661
662
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
663
664
665
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
666
667
            help="Optimize the model with torch.compile. Experimental feature.",
        )
668
        parser.add_argument(
669
            "--torch-compile-max-bs",
670
            type=int,
671
            default=ServerArgs.torch_compile_max_bs,
672
673
            help="Set the maximum batch size when using torch compile.",
        )
674
        parser.add_argument(
675
            "--cuda-graph-max-bs",
676
            type=int,
677
            default=ServerArgs.cuda_graph_max_bs,
678
679
            help="Set the maximum batch size for cuda graph.",
        )
680
681
682
683
684
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
685
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
686
        parser.add_argument(
687
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
688
            action="store_true",
689
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
690
        )
691
        parser.add_argument(
692
            "--triton-attention-reduce-in-fp32",
693
            action="store_true",
694
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
695
            "This only affects Triton attention kernels.",
696
        )
697
698
699
700
701
702
703
704
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
705
706
707
708
709
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
710

711
712
713
714
715
716
717
718
719
720
721
722
        # Deprecated arguments
        parser.add_argument(
            "--disable-flashinfer",
            action=DeprecatedAction,
            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action=DeprecatedAction,
            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
723
724
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from a parsed CLI namespace.

        The long-form parallelism options are first copied onto ``args`` under
        the short names ``tp_size`` and ``dp_size`` (``args`` is mutated in
        place). Every dataclass field is then read from ``args`` by name and
        forwarded to the constructor.

        Args:
            args: Namespace that carries ``tensor_parallel_size``,
                ``data_parallel_size`` and one attribute per dataclass field.

        Returns:
            A new ``cls`` instance populated from ``args``.
        """
        # Mirror the long CLI option names onto the field names.
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size

        init_kwargs = {}
        for field in dataclasses.fields(cls):
            init_kwargs[field.name] = getattr(args, field.name)
        return cls(**init_kwargs)

    def url(self):
        """Return the ``http://host:port`` URL built from ``self.host`` and ``self.port``.

        The host is wrapped in square brackets when ``is_ipv6`` reports it as
        an IPv6 address.
        """
        host = f"[{self.host}]" if is_ipv6(self.host) else self.host
        return f"http://{host}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
735

736
737
738
739
740
    def check_server_args(self):
        """Validate argument combinations and normalize ``lora_paths``.

        Raises:
            AssertionError: If ``tp_size`` is not divisible by ``nnodes``, if
                data parallelism (``dp_size > 1``) is requested with
                ``nnodes != 1``, or if the LoRA settings are not accepted
                (``max_loras_per_batch`` must be positive, and when
                ``lora_paths`` is set both ``disable_cuda_graph`` and
                ``disable_radix_cache`` must be enabled).

        Side effects:
            When ``lora_paths`` is a list it is replaced by a dict mapping
            adapter name to path. A ``name=path`` entry is split on the first
            ``=``; any other entry is used as both name and path.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert (
            self.dp_size <= 1 or self.nnodes == 1
        ), "multi-node data parallel is not supported"
        # FIXME: LoRA currently requires both cuda graph and radix cache to be
        # disabled.
        assert self.max_loras_per_batch > 0 and (
            self.lora_paths is None
            or (self.disable_cuda_graph and self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"

        if not isinstance(self.lora_paths, list):
            return

        entries = self.lora_paths
        adapters = self.lora_paths = {}
        for entry in entries:
            name, sep, path = entry.partition("=")
            if sep:
                adapters[name] = path
            else:
                adapters[entry] = entry

Lianmin Zheng's avatar
Lianmin Zheng committed
760

Lianmin Zheng's avatar
Lianmin Zheng committed
761
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Parse command line arguments into a `ServerArgs` instance.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    cli_parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli_parser)
    return ServerArgs.from_cli_args(cli_parser.parse_args(argv))


Lianmin Zheng's avatar
Lianmin Zheng committed
779
780
@dataclasses.dataclass
class PortArgs:
    """IPC file names and the NCCL port shared by the server processes."""

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def init_new(server_args) -> "PortArgs":
        """Create a fresh ``PortArgs`` for the given server arguments.

        The NCCL port is the first port reported available by
        ``is_port_available`` in the sequence ``server_args.port + 42``,
        ``server_args.port + 84``, ... Each IPC name is the path of a newly
        created temporary file that is left on disk (``delete=False``).

        Args:
            server_args: Object exposing an integer ``port`` attribute.

        Returns:
            A ``PortArgs`` with three distinct IPC file names and the chosen
            NCCL port.
        """
        port = server_args.port + 42
        while not is_port_available(port):
            port += 42

        def new_ipc_name() -> str:
            # Only the path is needed; close the handle deterministically
            # instead of leaving it for the garbage collector to close.
            with tempfile.NamedTemporaryFile(delete=False) as ipc_file:
                return ipc_file.name

        return PortArgs(
            tokenizer_ipc_name=new_ipc_name(),
            scheduler_input_ipc_name=new_ipc_name(),
            detokenizer_ipc_name=new_ipc_name(),
            nccl_port=port,
        )

806
807
808
809
810
811
812
813
814
815

class LoRAPathAction(argparse.Action):
    """Argparse action that stores LoRA path entries as a name -> path dict.

    Each value is either ``name=path`` (split on the first ``=``) or a bare
    path, which is then used as both the name and the path.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        setattr(namespace, self.dest, adapters)
        for entry in values:
            name, sep, path = entry.partition("=")
            if sep:
                adapters[name] = path
            else:
                adapters[entry] = entry
816
817
818
819
820
821
822
823
824
825


class DeprecatedAction(argparse.Action):
    """Argparse action for removed flags: using the flag raises an error.

    The flag consumes no values by default (``nargs=0``). When it appears on
    the command line, a ``ValueError`` carrying the option's help text is
    raised.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)