# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The arguments of the server."""

import argparse
import dataclasses
import json
import logging
import os
import random
import tempfile
from typing import List, Literal, Optional

from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
    configure_ipv6,
    get_amdgpu_memory_capacity,
    get_device,
    get_hpu_memory_capacity,
    get_nvgpu_memory_capacity,
    is_cuda,
    is_flashinfer_available,
    is_hip,
    is_port_available,
    is_remote_url,
    is_valid_ipv6_address,
    nullable_str,
)

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ServerArgs:
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    trust_remote_code: bool = False
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    quantization_param_path: Optional[str] = None
    context_length: Optional[int] = None
    device: Optional[str] = None
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    completion_template: Optional[str] = None
    is_embedding: bool = False
    revision: Optional[str] = None

    # Port for the HTTP server
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "fcfs"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    page_size: int = 1

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    stream_output: bool = False
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
    download_dir: Optional[str] = None
    base_gpu_id: int = 0
    gpu_id_step: int = 1

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    log_requests_level: int = 0
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_path: str = "sglang_storage"
    enable_cache_report: bool = False
    reasoning_parser: Optional[str] = None

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Expert parallelism
    ep_size: int = 1

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = None

    # Speculative decoding
    speculative_algorithm: Optional[str] = None
    speculative_draft_model_path: Optional[str] = None
    speculative_num_steps: Optional[int] = None
    speculative_eagle_topk: Optional[int] = None
    speculative_num_draft_tokens: Optional[int] = None
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
    speculative_token_map: Optional[str] = None

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    enable_custom_logit_processor: bool = False
    tool_call_parser: Optional[str] = None
    enable_hierarchical_cache: bool = False
    hicache_ratio: float = 2.0
    enable_flashinfer_mla: bool = False
    enable_flashmla: bool = False
    flashinfer_mla_disable_ragged: bool = False
    warmups: Optional[str] = None
    n_share_experts_fusion: Optional[int] = None
    disable_shared_experts_fusion: bool = False

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False

    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
    disaggregation_mode: str = "null"
    disaggregation_bootstrap_port: int = 8998

    def __post_init__(self):
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.device is None:
            self.device = get_device()

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        if is_cuda():
            gpu_mem = get_nvgpu_memory_capacity()
        elif is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        elif self.device == "hpu":
            gpu_mem = get_hpu_memory_capacity()
        else:
            # GPU memory is not known yet or no GPU is available.
            gpu_mem = None

        if is_hip():
            self.disable_shared_experts_fusion = True

        # Set mem fraction static, which depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.81
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88
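        # Worked example (illustrative): tp_size=4 selects 0.85, i.e. about
        # 85% of each GPU's memory is reserved for model weights plus the KV
        # cache pool (roughly 68 GB on an 80 GB device); the remainder is
        # left for activations and CUDA graphs.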

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None and gpu_mem < 25_000:
                self.chunked_prefill_size = 2048
            else:
                self.chunked_prefill_size = 8192

        assert self.chunked_prefill_size % self.page_size == 0

        if self.enable_flashmla:
            logger.warning(
                "FlashMLA only supports a page_size of 64; changing page_size to 64."
            )
            self.page_size = 64
        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on
            # lower-end GPUs with HBM < 25 GB, you can either disable cuda
            # graph or set `cuda_graph_max_bs` to a very small value to reduce
            # the memory overhead of creating cuda graphs, with almost no
            # impact on performance. However, when serving models with TP4 or
            # TP8, cuda graph must stay enabled to maintain high performance.
            # In that case, `cuda_graph_max_bs` can be set to 80 (half of the
            # default 160) to reduce the memory overhead of creating cuda
            # graphs. Logs from TP4 serving of qwen2-72b show that 80 is
            # sufficient and, compared to the original 160, avoids OOM on
            # lower-end GPUs.
            if gpu_mem is not None and gpu_mem < 25_000:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80
            else:
                self.cuda_graph_max_bs = 160

        # Choose kernel backends
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        if self.attention_backend is None:
            self.attention_backend = (
                "flashinfer" if is_flashinfer_available() else "triton"
            )
        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled when using the torch native attention backend."
            )
            self.disable_cuda_graph = True

        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.info(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            assert (
                self.dp_size > 1
            ), "Please set dp-size > 1. You can use 1 < dp-size <= tp-size."
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues."
            )

        self.enable_sp_layernorm = False
        # DeepEP MoE
        if self.enable_deepep_moe:
            if self.deepep_mode == "auto":
                assert (
                    not self.enable_dp_attention
                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
            self.ep_size = self.tp_size
            self.enable_sp_layernorm = (
                self.dp_size < self.tp_size if self.enable_dp_attention else True
            )
            logger.info(
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

        if (
            self.speculative_algorithm == "EAGLE"
            or self.speculative_algorithm == "EAGLE3"
        ):
            if self.max_running_requests is None:
                self.max_running_requests = 48
            self.disable_overlap_schedule = True
            logger.info(
                "Overlap scheduler is disabled because of using "
                "eagle speculative decoding."
            )

            # Auto choose parameters
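            # (Hedged note: `auto_choose_speculative_params` is expected to be
            # defined later in this module; it returns a
            # (num_steps, eagle_topk, num_draft_tokens) triple chosen for the
            # target model.)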
            if self.speculative_num_steps is None:
                assert (
                    self.speculative_eagle_topk is None
                    and self.speculative_num_draft_tokens is None
                )
                (
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
                ) = auto_choose_speculative_params(self)

            if self.page_size > 1 and self.speculative_eagle_topk > 1:
                self.speculative_eagle_topk = 1
                logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")

            # The token generated from the verify step is counted.
            # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

        if is_remote_url(self.model_path):
            self.load_format = "remote"

        # AMD-specific Triton attention KV splits default number
        if is_hip():
            self.triton_attention_num_kv_splits = 16

        # PD disaggregation
        if self.disaggregation_mode == "prefill":
            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for the prefill server")
            self.disable_overlap_schedule = True
            logger.warning("Overlap scheduler is disabled for the prefill server")
        elif self.disaggregation_mode == "decode":
            self.disable_radix_cache = True
            logger.warning("KV cache is forced as chunk cache for the decode server")
            self.disable_overlap_schedule = True
            logger.warning("Overlap scheduler is disabled for the decode server")

        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Model and port args
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip initializing the tokenizer and pass input_ids directly in generate requests.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
                "sharded_state",
                "gguf",
                "bitsandbytes",
                "layered",
                "remote",
            ],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling. "
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
            "quantization. "
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use the model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
                "gguf",
                "modelopt",
                "w8a8_int8",
                "w8a8_fp8",
                "moe_wna16",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The builtin completion template name or the path of the completion template file. This is only used for the OpenAI-compatible API server, and currently only for code completion.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        # Memory and scheduling
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading.",
        )
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )

        # Other runtime options
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of token length. A smaller value makes streaming smoother, while a larger value gives higher throughput.",
        )
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory for huggingface.",
        )
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )

        # Logging
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level.",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=ServerArgs.log_requests_level,
            help="0: Log metadata. 1: Log metadata and partial input/output. 2: Log every input/output.",
            choices=[0, 1, 2],
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch.",
        )

        # API related
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-path",
            type=str,
            default=ServerArgs.file_storage_path,
            help="The path of the file storage in backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )

        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
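        # Example (hedged, hypothetical value): overriding the context length
        # read from the model's config.json:
        #   --json-model-override-args '{"max_position_embeddings": 65536}'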

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either plain paths or renamed paths in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=ServerArgs.max_loras_per_batch,
            help="Maximum number of adapters for a running batch, including base-only requests.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default=ServerArgs.lora_backend,
            help="Choose the kernel backend for multi-LoRA serving.",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton", "torch_native", "fa3"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines", "llguidance", "none"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for grammar-guided decoding.",
        )
        parser.add_argument(
            "--enable-flashinfer-mla",
            action="store_true",
            help="Enable FlashInfer MLA optimization",
        )
        parser.add_argument(
            "--enable-flashmla",
            action="store_true",
            help="Enable FlashMLA decode optimization",
        )
        parser.add_argument(
            "--flashinfer-mla-disable-ragged",
            action="store_true",
            help="Do not use the ragged prefill wrapper when running FlashInfer MLA",
        )

        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
            choices=["EAGLE", "EAGLE3", "NEXTN"],
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
            help="The number of tokens sampled from the draft model in eagle2 each step.",
            default=ServerArgs.speculative_eagle_topk,
        )
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_draft_tokens,
        )
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The threshold for applying sparse decoding in double sparsity attention",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
        parser.add_argument(
            "--disable-outlines-disk-cache",
            action="store_true",
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enable data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enable expert parallelism for MoE. The ep size is equal to the tp size.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
        )
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
        parser.add_argument(
            "--tool-call-parser",
            type=str,
            choices=["qwen25", "mistral", "llama3"],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
        )
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            required=False,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="Enable the DeepEP MoE implementation for EP MoE.",
        )
        parser.add_argument(
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
            default=ServerArgs.deepep_mode,
            help="Select the mode when enabling DeepEP MoE: `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batches and `normal` for prefill batches.",
        )

        parser.add_argument(
            "--n-share-experts-fusion",
            type=int,
            default=None,
            help="The number of shared experts to replicate and fuse with the normal experts in DeepSeek V3/R1. "
            "We use tp_size by default.",
        )
        parser.add_argument(
            "--disable-shared-experts-fusion",
            action="store_true",
            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
        )

        # Server warmups
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before the server starts, e.g. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            action="store_true",
            help="Inject the outputs from jax as the input of every layer.",
        )

        # Disaggregation
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
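        # argparse stores `--tensor-parallel-size` / `--tp-size` under the
        # long dest `tensor_parallel_size`; copy the parsed values back onto
        # the short dataclass field names before constructing the instance.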
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        args.ep_size = args.expert_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        if is_valid_ipv6_address(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallelism is only supported when dp attention is enabled!"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

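        # Normalize lora_paths from a list to a dict. For example (illustrative
        # values), ["sql=/ckpt/sql_lora", "/ckpt/chat"] becomes
        # {"sql": "/ckpt/sql_lora", "/ckpt/chat": "/ckpt/chat"}.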
        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path


def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args
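
# Typical usage (sketch):
#
#     import sys
#
#     server_args = prepare_server_args(sys.argv[1:])
#     print(server_args.url())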


ZMQ_TCP_PORT_DELTA = 233


@dataclasses.dataclass
class PortArgs:
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        port = server_args.port + random.randint(100, 1000)
        while True:
            if is_port_available(port):
                break
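            # Step the candidate port by +42 until we pass 60000, then walk
            # back down by 43, keeping probes inside the usable port range.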
            if port < 60000:
                port += 42
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                nccl_port=port,
                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
            )
        else:
            # DP attention. Use TCP + port to handle both single-node and multi-node.
            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
            elif server_args.dist_init_addr.startswith("["):  # ipv6 address
                port_num, host = configure_ipv6(server_args.dist_init_addr)
                dist_init_addr = (host, str(port_num))
            else:
                dist_init_addr = server_args.dist_init_addr.split(":")

            assert (
                len(dist_init_addr) == 2
            ), "please provide --dist-init-addr as host:port of head node"

            dist_init_host, dist_init_port = dist_init_addr
            port_base = int(dist_init_port) + 1
            if dp_rank is None:
                scheduler_input_port = (
                    port_base + 3
                )  # TokenizerManager to DataParallelController
            else:
                scheduler_input_port = port_base + 3 + 1 + dp_rank
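
            # Worked example (illustrative address): with --dist-init-addr
            # 10.0.0.1:5000, port_base is 5001, so the tokenizer binds
            # tcp://10.0.0.1:5001, the detokenizer 5002, the RPC channel 5003,
            # and scheduler input 5004 (or 5005 + dp_rank per DP replica).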

            return PortArgs(
                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                nccl_port=port,
                rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
            )


class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path
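
# For example (illustrative paths), `--lora-paths sql=/ckpt/sql_lora /ckpt/chat`
# parses to {"sql": "/ckpt/sql_lora", "/ckpt/chat": "/ckpt/chat"}.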


class DeprecatedAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super(DeprecatedAction, self).__init__(
            option_strings, dest, nargs=nargs, **kwargs
        )

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)
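
# Sketch of how a removed flag might be registered (hypothetical flag name):
#
#     parser.add_argument(
#         "--some-removed-flag",
#         action=DeprecatedAction,
#         help="--some-removed-flag is deprecated. Use --its-replacement instead.",
#     )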


def auto_choose_speculative_params(self: ServerArgs):
    """
    Automatically choose the parameters for speculative decoding.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py.
    """
    if self.decrypted_config_file:
        config_path = self.decrypted_config_file
    else:
        config_path = os.path.join(self.model_path, "config.json")
    if not os.path.exists(config_path):
        raise ValueError(f"{config_path} is not found.")

    with open(config_path) as f:
        config = json.load(f)

    arch = config.get("architectures", ["Unknown"])[0]

    if arch in ["LlamaForCausalLM"]:
        # The default value for llama
        return (5, 4, 8)
    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
        # The default value for deepseek
        return (5, 4, 8)
    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
        return (5, 4, 8)
    else:
        # The default value for all other models
        return (5, 4, 8)
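

# Usage sketch (assumption: the tuple unpacks into the three speculative
# decoding fields, in the order suggested by this helper's defaults):
#
#     (
#         server_args.speculative_num_steps,
#         server_args.speculative_eagle_topk,
#         server_args.speculative_num_draft_tokens,
#     ) = auto_choose_speculative_params(server_args)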