server_args.py 66.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import json
19
import logging
20
import os
21
import random
22
import tempfile
23
from typing import List, Literal, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
Xihuai Wang's avatar
Xihuai Wang committed
26
from sglang.srt.reasoning_parser import ReasoningParser
27
from sglang.srt.utils import (
Vincent's avatar
Vincent committed
28
    configure_ipv6,
29
    get_device,
Lianmin Zheng's avatar
Lianmin Zheng committed
30
    get_device_memory_capacity,
31
    is_cuda,
32
    is_flashinfer_available,
HAI's avatar
HAI committed
33
    is_hip,
34
    is_port_available,
35
    is_remote_url,
36
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
37
    nullable_str,
38
)
39

40
41
# Module-level logger; ServerArgs.__post_init__ uses it to report options it overrides.
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
42
43
44

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
45
    # Model and tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
46
47
48
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
49
    skip_tokenizer_init: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
50
    load_format: str = "auto"
51
    trust_remote_code: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
52
    dtype: str = "auto"
53
    kv_cache_dtype: str = "auto"
Lianmin Zheng's avatar
Lianmin Zheng committed
54
    quantization: Optional[str] = None
Vincent's avatar
Vincent committed
55
    quantization_param_path: Optional[str] = None
56
    context_length: Optional[int] = None
57
    device: Optional[str] = None
58
    served_model_name: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
59
    chat_template: Optional[str] = None
60
    completion_template: Optional[str] = None
61
    is_embedding: bool = False
62
    enable_multimodal: Optional[bool] = None
63
    revision: Optional[str] = None
64
    impl: str = "auto"
Lianmin Zheng's avatar
Lianmin Zheng committed
65

66
    # Port for the HTTP server
Lianmin Zheng's avatar
Lianmin Zheng committed
67
68
69
70
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
71
    mem_fraction_static: Optional[float] = None
72
    max_running_requests: Optional[int] = None
73
    max_total_tokens: Optional[int] = None
74
    chunked_prefill_size: Optional[int] = None
75
    max_prefill_tokens: int = 16384
76
    schedule_policy: str = "fcfs"
77
    schedule_conservativeness: float = 1.0
78
    cpu_offload_gb: int = 0
79
    page_size: int = 1
Lianmin Zheng's avatar
Lianmin Zheng committed
80
81
82

    # Other runtime options
    tp_size: int = 1
83
84
    pp_size: int = 1
    max_micro_batch_size: Optional[int] = None
85
    stream_interval: int = 1
86
    stream_output: bool = False
87
    random_seed: Optional[int] = None
88
    constrained_json_whitespace_pattern: Optional[str] = None
89
    watchdog_timeout: float = 300
90
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
91
    download_dir: Optional[str] = None
92
    base_gpu_id: int = 0
93
    gpu_id_step: int = 1
Lianmin Zheng's avatar
Lianmin Zheng committed
94
95
96

    # Logging
    log_level: str = "info"
97
    log_level_http: Optional[str] = None
98
    log_requests: bool = False
99
    log_requests_level: int = 0
Liangsheng Yin's avatar
Liangsheng Yin committed
100
    show_time_cost: bool = False
101
    enable_metrics: bool = False
102
103
104
105
    bucket_time_to_first_token: Optional[List[float]] = None
    bucket_e2e_request_latency: Optional[List[float]] = None
    bucket_inter_token_latency: Optional[List[float]] = None
    collect_tokens_histogram: bool = False
106
    decode_log_interval: int = 40
107
    enable_request_time_stats_logging: bool = False
108
    kv_events_config: Optional[str] = None
Liangsheng Yin's avatar
Liangsheng Yin committed
109

110
    # API related
111
    api_key: Optional[str] = None
112
    file_storage_path: str = "sglang_storage"
113
    enable_cache_report: bool = False
Xihuai Wang's avatar
Xihuai Wang committed
114
    reasoning_parser: Optional[str] = None
115
    tool_call_parser: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
116

117
118
119
    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"
120

121
    # Multi-node distributed serving
122
    dist_init_addr: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
123
    nnodes: int = 1
124
    node_rank: int = 0
Lianmin Zheng's avatar
Lianmin Zheng committed
125
126
127

    # Model override args in JSON
    json_model_override_args: str = "{}"
128
    preferred_sampling_params: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
129

130
131
132
    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
133
    lora_backend: str = "triton"
134
135

    # Kernel backend
136
137
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
138
    grammar_backend: Optional[str] = None
139
    mm_attention_backend: Optional[str] = None
140

141
142
    # Speculative decoding
    speculative_algorithm: Optional[str] = None
143
    speculative_draft_model_path: Optional[str] = None
144
145
146
    speculative_num_steps: Optional[int] = None
    speculative_eagle_topk: Optional[int] = None
    speculative_num_draft_tokens: Optional[int] = None
147
148
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
149
    speculative_token_map: Optional[str] = None
150

151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
    # Expert parallelism
    ep_size: int = 1
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
    ep_num_redundant_experts: int = 0
    ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
    init_expert_location: str = "trivial"
    enable_eplb: bool = False
    eplb_algorithm: str = "auto"
    eplb_rebalance_num_iterations: int = 1000
    eplb_rebalance_layers_per_chunk: Optional[int] = None
    expert_distribution_recorder_mode: Optional[
        Literal["stat", "stat_approx", "per_pass", "per_token"]
    ] = None
    expert_distribution_recorder_buffer_size: Optional[int] = None
    enable_expert_distribution_metrics: bool = False
    deepep_config: Optional[str] = None
    moe_dense_tp_size: Optional[int] = None

171
172
    # Double Sparsity
    enable_double_sparsity: bool = False
Vincent's avatar
Vincent committed
173
    ds_channel_config_path: Optional[str] = None
174
175
176
177
178
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

179
    # Optimization/debug options
Lianmin Zheng's avatar
Lianmin Zheng committed
180
    disable_radix_cache: bool = False
181
182
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
183
    disable_cuda_graph: bool = False
184
    disable_cuda_graph_padding: bool = False
185
    enable_profile_cuda_graph: bool = False
186
    enable_nccl_nvls: bool = False
187
    enable_tokenizer_batch_encode: bool = False
188
    disable_outlines_disk_cache: bool = False
189
    disable_custom_all_reduce: bool = False
190
    enable_mscclpp: bool = False
191
    disable_overlap_schedule: bool = False
192
    disable_overlap_cg_plan: bool = False
193
    enable_mixed_chunk: bool = False
Ke Bao's avatar
Ke Bao committed
194
    enable_dp_attention: bool = False
195
    enable_dp_lm_head: bool = False
196
    enable_two_batch_overlap: bool = False
197
    enable_torch_compile: bool = False
198
    torch_compile_max_bs: int = 32
199
    torchao_config: str = ""
200
    enable_nan_detection: bool = False
201
    enable_p2p_check: bool = False
202
    triton_attention_reduce_in_fp32: bool = False
203
    triton_attention_num_kv_splits: int = 8
204
    num_continuous_decode_steps: int = 1
205
    delete_ckpt_after_loading: bool = False
206
    enable_memory_saver: bool = False
207
    allow_auto_truncate: bool = False
208
    enable_custom_logit_processor: bool = False
209
    enable_hierarchical_cache: bool = False
210
    hicache_ratio: float = 2.0
Zhiqiang Xie's avatar
Zhiqiang Xie committed
211
212
    hicache_size: int = 0
    hicache_write_policy: str = "write_through_selective"
213
    flashinfer_mla_disable_ragged: bool = False
214
    disable_shared_experts_fusion: bool = False
215
    disable_chunked_prefix_cache: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
216
    disable_fast_image_processor: bool = False
217
    warmups: Optional[str] = None
218
219
220
221
222

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False
223
    debug_tensor_dump_prefill_only: bool = False
224

Byron Hsu's avatar
Byron Hsu committed
225
226
    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
    disaggregation_mode: str = "null"
227
    disaggregation_transfer_backend: str = "mooncake"
228
    disaggregation_bootstrap_port: int = 8998
229
    disaggregation_ib_device: Optional[str] = None
230
    num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
231
    pdlb_url: Optional[str] = None
Byron Hsu's avatar
Byron Hsu committed
232

Lianmin Zheng's avatar
Lianmin Zheng committed
233
    def __post_init__(self):
        """Fill in derived defaults and reconcile interdependent server options.

        Runs automatically after the dataclass-generated ``__init__``. Fields
        left as ``None`` (tokenizer path, device, served model name, random
        seed, memory fraction, chunked prefill size, backends, speculative
        decoding parameters, ...) are filled in. Incompatible combinations are
        adjusted in place, with a log message, or rejected. As a side effect
        it sets the ``SGLANG_ENABLE_TORCH_COMPILE`` and
        ``SGLANG_DISABLE_OUTLINES_DISK_CACHE`` environment variables.

        Raises:
            AssertionError: if options are combined in an unsupported way,
                e.g. DP attention with ``dp_size <= 1``, ``enable_dp_lm_head``
                without ``enable_dp_attention``, or an unsupported
                ``moe_dense_tp_size``.
        """
        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.warning(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.device is None:
            self.device = get_device()

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # May be None (checked below). Units presumably MB -- the thresholds
        # used below (81920, 96 * 1024, 180 * 1000) suggest so; TODO confirm
        # against get_device_memory_capacity.
        gpu_mem = get_device_memory_capacity(self.device)

        # Set mem fraction static, which depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            parallel_size = self.tp_size * self.pp_size
            if gpu_mem is not None and gpu_mem <= 81920:
                if parallel_size >= 16:
                    self.mem_fraction_static = 0.79
                elif parallel_size >= 8:
                    self.mem_fraction_static = 0.81
                elif parallel_size >= 4:
                    self.mem_fraction_static = 0.85
                elif parallel_size >= 2:
                    self.mem_fraction_static = 0.87
                else:
                    self.mem_fraction_static = 0.88
            else:
                self.mem_fraction_static = 0.88
            if gpu_mem is not None and gpu_mem > 180 * 1000 and is_cuda():
                self.mem_fraction_static = 0.79
            elif gpu_mem is not None and gpu_mem > 96 * 1024:
                mem_fraction = self.mem_fraction_static
                # 15 GB + additional 3GB for cuda graph
                reserve_mem = 1024 * 18
                # need reserve more memory for spec cuda graph
                if self.speculative_algorithm is not None:
                    reserve_mem = 1024 * 20
                self.mem_fraction_static = min(
                    mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem,
                    (gpu_mem - reserve_mem) / gpu_mem,
                )
            else:
                if self.speculative_algorithm is not None:
                    self.mem_fraction_static *= 0.95

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None and gpu_mem > 180_000:
                self.chunked_prefill_size = 16384
            elif gpu_mem is not None and gpu_mem < 25_000:
                self.chunked_prefill_size = 2048
            elif self.disaggregation_mode != "null":
                self.chunked_prefill_size = 16384
            else:
                self.chunked_prefill_size = 8192

        # NOTE(review): page_size may still be overridden below (flashmla -> 64,
        # cutlass_mla -> 128), and chunked_prefill_size may later be divided by
        # dp_size, so this divisibility check only covers the values at this
        # point.
        assert self.chunked_prefill_size % self.page_size == 0

        assert self.moe_dense_tp_size in {
            1,
            None,
        }, "moe_dense_tp_size only support 1 and None currently"

        # Some MLA attention kernels require a fixed page size.
        if self.attention_backend == "flashmla":
            logger.warning(
                "FlashMLA only supports a page_size of 64, change page_size to 64."
            )
            self.page_size = 64

        if self.attention_backend == "cutlass_mla":
            logger.warning(
                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on
            # lower-end GPUs with HBM<25G, you can either disable cuda graph or
            # set `cuda_graph_max_bs` to a very small value to reduce the memory
            # overhead of creating cuda graphs, with almost no impact on
            # performance. However, when serving models with TP4 or TP8, we need
            # to enable cuda graph to maintain high performance. In this case,
            # we can set `cuda_graph_max_bs` to 80 (half of the default value
            # 160) to reduce the memory overhead of creating cuda graphs.
            # Looking at the logs from TP4 serving of qwen2-72b, a value of 80
            # is sufficient and can reduce the memory overhead of creating cuda
            # graphs on lower-end GPUs compared to the original 160, avoiding
            # OOM issues.
            if gpu_mem is not None and gpu_mem < 25_000:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80

        # Set kernel backends for hpu device
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        # Set kernel backends
        if self.device == "cpu":
            if self.attention_backend is None:
                self.attention_backend = "intel_amx"
            self.sampling_backend = "pytorch"

        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True

        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
            )

        # The DP LM head is only supported on top of DP attention.
        if self.enable_dp_lm_head:
            assert (
                self.enable_dp_attention
            ), "Please enable dp attention when setting enable_dp_lm_head. "

        # DeepEP MoE
        self.enable_sp_layernorm = False
        if self.enable_deepep_moe:
            if self.deepep_mode == "auto":
                assert (
                    not self.enable_dp_attention
                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
            if self.deepep_mode == "normal":
                logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                self.disable_cuda_graph = True
            self.ep_size = self.tp_size
            self.enable_sp_layernorm = (
                self.dp_size < self.tp_size if self.enable_dp_attention else True
            )
            logger.warning(
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        if self.pp_size > 1:
            self.disable_overlap_schedule = True
            logger.warning(
                "Pipeline parallelism is incompatible with overlap schedule."
            )

        if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
            self.expert_distribution_recorder_mode = "stat"
            logger.info(
                "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
            )

        # NOTE(review): init_expert_location defaults to "trivial" (not None),
        # so with default settings this branch always fires whenever
        # ep_dispatch_algorithm is unset. Confirm whether `!= "trivial"` was
        # intended before changing it.
        if (self.enable_eplb or (self.init_expert_location is not None)) and (
            self.ep_dispatch_algorithm is None
        ):
            self.ep_dispatch_algorithm = "static"
            logger.info(
                "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
            )

        if self.enable_expert_distribution_metrics and (
            self.expert_distribution_recorder_mode is None
        ):
            self.expert_distribution_recorder_mode = "stat"

        # Size the recorder buffer to one rebalance window when available.
        if self.expert_distribution_recorder_buffer_size is None:
            if (x := self.eplb_rebalance_num_iterations) is not None:
                self.expert_distribution_recorder_buffer_size = x
            elif self.expert_distribution_recorder_mode is not None:
                self.expert_distribution_recorder_buffer_size = 1000

        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
            if self.max_running_requests is None:
                self.max_running_requests = 48
            self.disable_overlap_schedule = True
            logger.warning(
                "Overlap scheduler is disabled because of using "
                "eagle speculative decoding."
            )
            if self.enable_mixed_chunk:
                self.enable_mixed_chunk = False
                logger.warning(
                    "Mixed chunked prefill is disabled because of using "
                    "eagle speculative decoding."
                )

            model_arch = get_model_arch(self)

            # Auto set draft_model_path DeepSeek-V3/R1
            if model_arch == "DeepseekV3ForCausalLM":
                if self.speculative_draft_model_path is None:
                    self.speculative_draft_model_path = self.model_path
                else:
                    logger.warning(
                        "DeepSeek MTP does not require setting speculative_draft_model_path."
                    )

            # Auto choose parameters; they are only chosen as a group, so the
            # other two must also be unset when num_steps is unset.
            if self.speculative_num_steps is None:
                assert (
                    self.speculative_eagle_topk is None
                    and self.speculative_num_draft_tokens is None
                )
                (
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
                ) = auto_choose_speculative_params(self)

            if self.page_size > 1 and self.speculative_eagle_topk > 1:
                self.speculative_eagle_topk = 1
                logger.warning(
                    "speculative_eagle_topk is adjusted to 1 when page_size > 1"
                )

            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
            ):
                logger.warning(
                    "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
                )
                self.speculative_num_draft_tokens = self.speculative_num_steps + 1

            # The token generated from the verify step is counted.
            # If speculative_num_steps >= speculative_num_draft_tokens, the
            # additional tokens will definitely be discarded. The invariant
            # speculative_num_steps < speculative_num_draft_tokens is therefore
            # expected, but it is not enforced here.

        # GGUF
        if self.load_format in ("auto", "gguf") and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

        if is_remote_url(self.model_path):
            self.load_format = "remote"

        # AMD-specific Triton attention KV splits default number.
        # NOTE(review): this unconditionally overrides any user-supplied value.
        if is_hip():
            self.triton_attention_num_kv_splits = 16

        # PD disaggregation
        if self.disaggregation_mode == "prefill":
            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for prefill server")
        elif self.disaggregation_mode == "decode":
            self.disable_radix_cache = True
            logger.warning("KV cache is forced as chunk cache for decode server")

        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
510
511
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
512
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
513
514
515
516
517
518
519
520
521
522
523
524
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
525
526
527
528
529
530
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
531
532
533
534
535
536
537
538
539
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
540
541
542
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
543
            help="If set, skip init tokenizer and pass input_ids in generate request.",
544
        )
545
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
546
547
548
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
549
550
551
552
553
554
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
555
                "sharded_state",
556
557
                "gguf",
                "bitsandbytes",
558
                "layered",
559
                "remote",
560
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
561
562
563
564
565
566
567
568
569
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
570
            "which is mainly for profiling."
571
572
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
573
574
575
576
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
577
        )
578
579
580
581
582
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
583
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
584
            "--dtype",
Cody Yu's avatar
Cody Yu committed
585
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
586
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
587
588
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
589
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
590
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
591
592
593
594
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
595
596
            '* "float32" for FP32 precision.',
        )
597
598
599
600
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
bjmsong's avatar
bjmsong committed
601
602
603
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
604
605
606
607
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
608
609
610
611
612
613
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
614
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
615
                "bitsandbytes",
616
                "gguf",
617
                "modelopt",
618
                "modelopt_fp4",
619
                "w8a8_int8",
HandH1998's avatar
HandH1998 committed
620
                "w8a8_fp8",
AniZpZ's avatar
AniZpZ committed
621
                "moe_wna16",
HandH1998's avatar
HandH1998 committed
622
                "qoq",
Ying Sheng's avatar
Ying Sheng committed
623
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
624
625
            help="The quantization method.",
        )
626
627
628
629
630
631
632
633
634
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
635
636
637
638
639
640
641
642
643
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
644
            default=ServerArgs.device,
645
            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
646
        )
647
648
649
650
651
652
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
653
654
655
656
657
658
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
659
660
661
662
663
664
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
        )
665
666
667
668
669
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
670
671
672
673
674
675
        parser.add_argument(
            "--enable-multimodal",
            default=ServerArgs.enable_multimodal,
            action="store_true",
            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
        )
676
677
678
679
680
681
682
683
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
684

685
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
686
687
688
689
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
690
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
691
        )
692
693
694
695
696
697
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
698
699
700
701
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
702
703
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
704
        )
705
706
707
708
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
709
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
710
711
712
713
714
715
716
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
717
        parser.add_argument(
718
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
719
            type=str,
720
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
721
            choices=["lpm", "random", "fcfs", "dfs-weight"],
722
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
723
        )
724
725
726
727
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
728
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
729
        )
730
731
732
733
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
734
            help="How many GBs of RAM to reserve for CPU offloading.",
735
        )
736
737
738
739
740
741
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
742
743
744
745
746
747
748
749
750
751
752
753
        parser.add_argument(
            "--impl",
            type=str,
            default=ServerArgs.impl,
            help="Which implementation of the model to use.\n\n"
            '* "auto" will try to use the SGLang implementation if it exists '
            "and fall back to the Transformers implementation if no SGLang "
            "implementation is available.\n"
            '* "sglang" will use the SGLang model implementation.\n'
            '* "transformers" will use the Transformers model '
            "implementation.\n",
        )
754

755
        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
756
        parser.add_argument(
757
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
758
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
759
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
760
            default=ServerArgs.tp_size,
761
            help="The tensor parallelism size.",
762
        )
763
764
765
766
767
768
769
770
771
772
773
774
775
        parser.add_argument(
            "--pipeline-parallel-size",
            "--pp-size",
            type=int,
            default=ServerArgs.pp_size,
            help="The pipeline parallelism size.",
        )
        parser.add_argument(
            "--max-micro-batch-size",
            type=int,
            default=ServerArgs.max_micro_batch_size,
            help="The maximum micro batch size in pipeline parallelism.",
        )
776
777
778
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
779
            default=ServerArgs.stream_interval,
780
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
781
        )
782
783
784
785
786
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
787
788
789
790
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
791
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
792
        )
793
794
795
796
797
798
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
799
800
801
802
803
804
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
805
806
807
808
809
810
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
811
812
813
814
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
815
            help="Model download directory for huggingface.",
816
        )
817
818
819
820
821
822
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
823
824
825
826
827
828
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )
829
830

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
831
832
833
834
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
835
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
836
        )
837
        parser.add_argument(
838
839
840
841
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
842
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
843
        parser.add_argument(
844
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
845
            action="store_true",
846
847
848
849
850
851
852
853
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
            choices=[0, 1, 2],
Lianmin Zheng's avatar
Lianmin Zheng committed
854
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
855
856
857
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
858
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
859
        )
860
861
862
863
864
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
        parser.add_argument(
            "--bucket-time-to-first-token",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_time_to_first_token,
            help="The buckets of time to first token, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-inter-token-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_inter_token_latency,
            help="The buckets of inter-token latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-e2e-request-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_e2e_request_latency,
            help="The buckets of end-to-end request latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--collect-tokens-histogram",
            action="store_true",
            default=ServerArgs.collect_tokens_histogram,
            help="Collect prompt/generation tokens histogram.",
        )
892
893
894
895
896
897
        parser.add_argument(
            "--kv-events-config",
            type=str,
            default=None,
            help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
        )
898
899
900
901
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
902
            help="The log interval of decode batch.",
903
        )
904
905
906
907
908
909
        parser.add_argument(
            "--enable-request-time-stats-logging",
            action="store_true",
            default=ServerArgs.enable_request_time_stats_logging,
            help="Enable per request time stats logging",
        )
910

911
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
912
913
914
915
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
916
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
917
        )
918
        parser.add_argument(
919
            "--file-storage-path",
920
            type=str,
921
            default=ServerArgs.file_storage_path,
922
923
            help="The path of the file storage in backend.",
        )
924
925
926
927
928
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Xihuai Wang's avatar
Xihuai Wang committed
929
930
931
932
933
934
935
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )
936
937
938
939
940
941
942
        parser.add_argument(
            "--tool-call-parser",
            type=str,
            choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
943

944
945
        # Data parallelism
        parser.add_argument(
946
            "--data-parallel-size",
947
948
949
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
950
            help="The data parallelism size.",
951
952
953
954
955
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
956
            help="The load balancing strategy for data parallelism.",
957
958
959
960
961
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
962

963
        # Multi-node distributed serving
964
        parser.add_argument(
965
            "--dist-init-addr",
966
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
967
            type=str,
968
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
969
970
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
971
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
972
        )
973
974
975
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
976

Lianmin Zheng's avatar
Lianmin Zheng committed
977
978
979
980
981
982
983
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
984
985
986
987
988
        parser.add_argument(
            "--preferred-sampling-params",
            type=str,
            help="json-formatted sampling settings that will be returned in /get_model_info",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
989

990
991
992
993
994
995
996
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
997
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
998
999
1000
1001
1002
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
1003
1004
1005
1006
1007
1008
1009
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
1010
1011
1012
        )

        # Kernel backend
1013
1014
1015
        parser.add_argument(
            "--attention-backend",
            type=str,
1016
            choices=[
1017
                "aiter",
1018
                "cutlass_mla",
1019
                "fa3",
1020
                "flashinfer",
1021
                "flashmla",
1022
                "intel_amx",
1023
1024
                "torch_native",
                "triton",
1025
            ],
1026
1027
1028
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
1029
1030
1031
1032
1033
1034
1035
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
1036
1037
1038
        parser.add_argument(
            "--grammar-backend",
            type=str,
1039
            choices=["xgrammar", "outlines", "llguidance", "none"],
1040
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
1041
            help="Choose the backend for grammar-guided decoding.",
1042
        )
1043

1044
1045
1046
1047
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
James Liu's avatar
James Liu committed
1048
            choices=["EAGLE", "EAGLE3", "NEXTN"],
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
1065
            help="The number of tokens sampled from the draft model in eagle2 each step.",
1066
1067
            default=ServerArgs.speculative_eagle_topk,
        )
1068
1069
1070
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
1071
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
1072
1073
            default=ServerArgs.speculative_num_draft_tokens,
        )
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
1086
1087
1088
1089
1090
1091
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
        parser.add_argument(
            "--mm-attention-backend",
            type=str,
            choices=["sdpa", "fa3", "triton_attn"],
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )

        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="Enabling DeepEP MoE implementation for EP MoE.",
        )
        parser.add_argument(
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
            default="auto",
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
        parser.add_argument(
            "--ep-num-redundant-experts",
            type=int,
            default=ServerArgs.ep_num_redundant_experts,
            help="Allocate this number of redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--ep-dispatch-algorithm",
            type=str,
            default=ServerArgs.ep_dispatch_algorithm,
            help="The algorithm to choose ranks for redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--init-expert-location",
            type=str,
            default=ServerArgs.init_expert_location,
            help="Initial location of EP experts.",
        )
        parser.add_argument(
            "--enable-eplb",
            action="store_true",
            help="Enable EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-algorithm",
            type=str,
            default=ServerArgs.eplb_algorithm,
            help="Chosen EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-rebalance-num-iterations",
            type=int,
            default=ServerArgs.eplb_rebalance_num_iterations,
            help="Number of iterations to automatically trigger a EPLB re-balance.",
        )
        parser.add_argument(
            "--eplb-rebalance-layers-per-chunk",
            type=int,
            default=ServerArgs.eplb_rebalance_layers_per_chunk,
            help="Number of layers to rebalance per forward pass.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-mode",
            type=str,
            default=ServerArgs.expert_distribution_recorder_mode,
            help="Mode of expert distribution recorder.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-buffer-size",
            type=int,
            default=ServerArgs.expert_distribution_recorder_buffer_size,
            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
        )
        parser.add_argument(
            "--enable-expert-distribution-metrics",
            action="store_true",
            help="Enable logging metrics for expert balancedness",
        )
        parser.add_argument(
            "--deepep-config",
            type=str,
            default=ServerArgs.deepep_config,
            help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
        )
        parser.add_argument(
            "--moe-dense-tp-size",
            type=int,
            default=ServerArgs.moe_dense_tp_size,
            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
        )
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

1233
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
1234
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
1235
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
1236
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
1237
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1238
        )
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
1251
1252
1253
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
1254
            help="Disable cuda graph.",
1255
        )
1256
        parser.add_argument(
1257
1258
            "--disable-cuda-graph-padding",
            action="store_true",
1259
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
1260
        )
1261
1262
1263
1264
1265
        parser.add_argument(
            "--enable-profile-cuda-graph",
            action="store_true",
            help="Enable profiling of cuda graph capture.",
        )
1266
1267
1268
1269
1270
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
1271
1272
1273
1274
1275
        parser.add_argument(
            "--enable-tokenizer-batch-encode",
            action="store_true",
            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
        )
1276
        parser.add_argument(
1277
            "--disable-outlines-disk-cache",
1278
            action="store_true",
1279
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
1280
        )
1281
1282
1283
1284
1285
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
1286
1287
1288
1289
1290
        parser.add_argument(
            "--enable-mscclpp",
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1291
        parser.add_argument(
1292
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
1293
            action="store_true",
1294
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1295
        )
1296
1297
1298
1299
1300
        parser.add_argument(
            "--disable-overlap-cg-plan",
            action="store_true",
            help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
        )
1301
1302
1303
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
1304
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
1305
        )
Ke Bao's avatar
Ke Bao committed
1306
1307
1308
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
1309
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
Ke Bao's avatar
Ke Bao committed
1310
        )
1311
1312
1313
1314
1315
        parser.add_argument(
            "--enable-dp-lm-head",
            action="store_true",
            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
        )
1316
1317
1318
1319
1320
        parser.add_argument(
            "--enable-two-batch-overlap",
            action="store_true",
            help="Enabling two micro batches to overlap.",
        )
1321
1322
1323
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
1324
1325
            help="Optimize the model with torch.compile. Experimental feature.",
        )
1326
        parser.add_argument(
1327
            "--torch-compile-max-bs",
1328
            type=int,
1329
            default=ServerArgs.torch_compile_max_bs,
1330
1331
            help="Set the maximum batch size when using torch compile.",
        )
1332
1333
1334
1335
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
1336
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
1337
        )
1338
1339
1340
1341
1342
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1343
        parser.add_argument(
1344
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
1345
            action="store_true",
1346
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1347
        )
1348
        parser.add_argument(
1349
            "--triton-attention-reduce-in-fp32",
1350
            action="store_true",
1351
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
1352
            "This only affects Triton attention kernels.",
1353
        )
1354
1355
1356
1357
1358
1359
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
1360
1361
1362
1363
1364
1365
1366
1367
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
1368
1369
1370
1371
1372
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
1373
1374
1375
1376
1377
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
1378
1379
1380
1381
1382
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
1383
1384
1385
1386
1387
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
1388
1389
1390
1391
1392
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
1393
1394
1395
1396
1397
1398
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
Zhiqiang Xie's avatar
Zhiqiang Xie committed
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
        parser.add_argument(
            "--hicache-size",
            type=int,
            default=ServerArgs.hicache_size,
            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
        )
        parser.add_argument(
            "--hicache-write-policy",
            type=str,
            choices=["write_back", "write_through", "write_through_selective"],
            default=ServerArgs.hicache_write_policy,
            help="The write policy of hierarchical cache.",
        )
1412
        parser.add_argument(
1413
            "--flashinfer-mla-disable-ragged",
1414
            action="store_true",
1415
            help="Not using ragged prefill wrapper when running flashinfer mla",
1416
        )
1417
        parser.add_argument(
1418
1419
1420
            "--disable-shared-experts-fusion",
            action="store_true",
            help="Disable shared experts fusion optimization for deepseek v3/r1.",
1421
        )
1422
1423
1424
1425
1426
        parser.add_argument(
            "--disable-chunked-prefix-cache",
            action="store_true",
            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1427
1428
1429
1430
1431
        parser.add_argument(
            "--disable-fast-image-processor",
            action="store_true",
            help="Adopt base image processor instead of fast image processor.",
        )
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )
1459
1460
1461
1462
1463
        parser.add_argument(
            "--debug-tensor-dump-prefill-only",
            action="store_true",
            help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
        )
1464

Byron Hsu's avatar
Byron Hsu committed
1465
1466
1467
1468
1469
1470
1471
1472
        # Disaggregation
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
1473
1474
1475
1476
        parser.add_argument(
            "--disaggregation-transfer-backend",
            type=str,
            default=ServerArgs.disaggregation_transfer_backend,
1477
            choices=["mooncake", "nixl"],
1478
1479
            help="The backend for disaggregation transfer. Default is mooncake.",
        )
1480
1481
1482
1483
1484
1485
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )
1486
1487
1488
1489
        parser.add_argument(
            "--disaggregation-ib-device",
            type=str,
            default=ServerArgs.disaggregation_ib_device,
1490
1491
1492
            help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
            "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
            "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
1493
        )
1494
1495
1496
1497
1498
1499
        parser.add_argument(
            "--num-reserved-decode-tokens",
            type=int,
            default=ServerArgs.num_reserved_decode_tokens,
            help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
        )
1500
1501
1502
1503
1504
1505
        parser.add_argument(
            "--pdlb-url",
            type=str,
            default=None,
            help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
        )
Byron Hsu's avatar
Byron Hsu committed
1506

Lianmin Zheng's avatar
Lianmin Zheng committed
1507
1508
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
1509
        args.tp_size = args.tensor_parallel_size
1510
        args.pp_size = args.pipeline_parallel_size
1511
        args.dp_size = args.data_parallel_size
xiaobochen's avatar
xiaobochen committed
1512
        args.ep_size = args.expert_parallel_size
Lianmin Zheng's avatar
Lianmin Zheng committed
1513
1514
1515
1516
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the server's HTTP base URL.

        IPv6 hosts are wrapped in square brackets so the port separator is
        unambiguous.
        """
        host = f"[{self.host}]" if is_valid_ipv6_address(self.host) else self.host
        return f"http://{host}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
1521

1522
1523
    def check_server_args(self):
        assert (
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
            self.tp_size * self.pp_size
        ) % self.nnodes == 0, "tp_size must be divisible by number of nodes"

        # FIXME pp constraints
        if self.pp_size > 1:
            assert (
                self.disable_overlap_schedule
                and self.speculative_algorithm is None
                and not self.enable_mixed_chunk
            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

1535
        assert not (
1536
1537
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"
1538
1539
1540
1541
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_radix_cache)
1542
        ), "compatibility of lora and radix attention is in progress"
1543
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
1544
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
1545

1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path

Lianmin Zheng's avatar
Lianmin Zheng committed
1556

Lianmin Zheng's avatar
Lianmin Zheng committed
1557
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    return ServerArgs.from_cli_args(raw_args)


1575
1576
1577
# Offset added to the server port to form the default head-node TCP address
# (127.0.0.1:<port + delta>) that PortArgs.init_new uses when DP attention is
# enabled on a single node and no --dist-init-addr is given.
ZMQ_TCP_PORT_DELTA = 233


Lianmin Zheng's avatar
Lianmin Zheng committed
1578
1579
@dataclasses.dataclass
class PortArgs:
    """Endpoints and ports used for communication between server processes.

    The ``*_ipc_name`` fields are ZMQ endpoint strings: ``ipc://<temp file>``
    in the default single-node case, or ``tcp://host:port`` when DP attention
    is enabled.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    @staticmethod
    def _new_ipc_name() -> str:
        """Return a fresh ``ipc://`` endpoint named after a new temporary file.

        The file is created with ``delete=False`` so it stays on disk, and its
        handle is closed right away so no file descriptor is leaked.
        """
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            return f"ipc://{tmp.name}"

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate the endpoints for a new set of server processes.

        Args:
            server_args: The ServerArgs of the server being launched.
            dp_rank: Only used when DP attention is enabled. ``None`` selects the
                TokenizerManager -> DataParallelController scheduler input port;
                otherwise each rank gets its own port after that one.

        Returns:
            A new PortArgs.

        Raises:
            AssertionError: if DP attention needs a head-node address and
                ``--dist-init-addr`` is missing or not of the form host:port.
        """
        # Start at a random offset above the server port and probe until a
        # free port is found for NCCL initialization.
        port = server_args.port + random.randint(100, 1000)
        while not is_port_available(port):
            if port < 60000:
                port += 42
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=PortArgs._new_ipc_name(),
                scheduler_input_ipc_name=PortArgs._new_ipc_name(),
                detokenizer_ipc_name=PortArgs._new_ipc_name(),
                nccl_port=port,
                rpc_ipc_name=PortArgs._new_ipc_name(),
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
        else:
            # Multi-node requires an explicit head-node address; fail with a
            # clear message instead of an AttributeError on None.
            assert (
                server_args.dist_init_addr is not None
            ), "please provide --dist-init-addr as host:port of head node"
            if server_args.dist_init_addr.startswith("["):  # ipv6 address
                port_num, host = configure_ipv6(server_args.dist_init_addr)
                dist_init_addr = (host, str(port_num))
            else:
                dist_init_addr = server_args.dist_init_addr.split(":")

        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"

        dist_init_host, dist_init_port = dist_init_addr
        # Endpoints use consecutive ports starting right after the head port:
        # base -> tokenizer, base+1 -> detokenizer, base+2 -> rpc,
        # base+3 (and above, per dp_rank) -> scheduler input.
        port_base = int(dist_init_port) + 1
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 3
        else:
            scheduler_input_port = port_base + 3 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            nccl_port=port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
        )
1643

1644
1645
1646
1647
1648
1649
1650
1651
1652
1653

class LoRAPathAction(argparse.Action):
    """argparse action that stores LoRA adapters as a ``{name: path}`` dict.

    Each value is either ``name=path`` (split at the first ``=``) or a bare
    path, which then also serves as the adapter name.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        setattr(namespace, self.dest, adapters)
        for entry in values:
            name, sep, path = entry.partition("=")
            if sep:
                adapters[name] = path
            else:
                adapters[entry] = entry
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663


class DeprecatedAction(argparse.Action):
    """argparse action for removed flags: passing the flag raises ValueError.

    The action's ``help`` text is used as the error message.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        # nargs defaults to 0, so the flag takes no values.
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)
1664
1665


1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
def get_model_arch(args: ServerArgs):
    """Return the first entry of ``architectures`` in the model's config.

    ``args.json_model_override_args`` is decoded from JSON and passed to
    ``get_config`` as ``model_override_args``.
    """
    override_args = json.loads(args.json_model_override_args)
    config = get_config(
        args.model_path,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
        model_override_args=override_args,
    )
    architectures = config.architectures
    return architectures[0]


1676
def auto_choose_speculative_params(self: ServerArgs):
    """
    Automatically choose the parameters for speculative decoding.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py

    Args:
        self: The server arguments; used to look up the model architecture.

    Returns:
        A 3-tuple of speculative decoding parameters chosen by the model's
        architecture (presumably (num_steps, eagle_topk, num_draft_tokens) —
        TODO confirm against the caller). Unlisted architectures get the
        generic default.
    """
    arch = get_model_arch(self)

    default_params = (5, 4, 8)
    # Per-architecture tuned values. Llama and Grok currently match the
    # generic default but are listed so they can be tuned independently.
    tuned_params = {
        "LlamaForCausalLM": (5, 4, 8),
        "DeepseekV3ForCausalLM": (3, 1, 4),
        "DeepseekV2ForCausalLM": (3, 1, 4),
        "Grok1ForCausalLM": (5, 4, 8),
        "Grok1VForCausalLM": (5, 4, 8),
    }
    return tuned_params.get(arch, default_params)