server_args.py 68 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import json
19
import logging
20
import os
21
import random
22
import tempfile
23
from typing import List, Literal, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
Xihuai Wang's avatar
Xihuai Wang committed
26
from sglang.srt.reasoning_parser import ReasoningParser
27
from sglang.srt.utils import (
Vincent's avatar
Vincent committed
28
    configure_ipv6,
29
    get_device,
Lianmin Zheng's avatar
Lianmin Zheng committed
30
    get_device_memory_capacity,
31
    is_flashinfer_available,
HAI's avatar
HAI committed
32
    is_hip,
33
    is_port_available,
34
    is_remote_url,
35
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
36
    nullable_str,
37
)
38

39
40
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
41
42
43

@dataclasses.dataclass
class ServerArgs:
    """Container for all server launch arguments.

    Fields are grouped by concern. Several fields default to ``None`` (or to a
    placeholder) and are filled in or overridden by ``__post_init__``; those
    are marked with a comment next to the field.
    """

    # Model and tokenizer
    model_path: str
    # Falls back to `model_path` in __post_init__ when left as None.
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    # May be rewritten to "gguf" or "remote" in __post_init__ based on model_path.
    load_format: str = "auto"
    trust_remote_code: bool = False
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    # Set to "gguf" in __post_init__ when a GGUF file is detected.
    quantization: Optional[str] = None
    quantization_param_path: Optional[str] = None
    context_length: Optional[int] = None
    # Resolved through get_device() in __post_init__ when left as None.
    device: Optional[str] = None
    # Falls back to `model_path` in __post_init__ when left as None.
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    completion_template: Optional[str] = None
    is_embedding: bool = False
    enable_multimodal: Optional[bool] = None
    revision: Optional[str] = None
    impl: str = "auto"

    # Host and port for the HTTP server
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    # The three Optional sizes below are derived in __post_init__ when left as
    # None (mem_fraction_static and chunked_prefill_size from the detected
    # device memory capacity).
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "fcfs"
    # Multiplied by 0.3 in __post_init__ when enable_dp_attention is set.
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    # Forced to 64 (flashmla) or 128 (cutlass_mla) in __post_init__.
    page_size: int = 1

    # Other runtime options
    tp_size: int = 1
    pp_size: int = 1
    max_micro_batch_size: Optional[int] = None
    stream_interval: int = 1
    stream_output: bool = False
    # Replaced by a random integer in __post_init__ when left as None.
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
    download_dir: Optional[str] = None
    base_gpu_id: int = 0
    gpu_id_step: int = 1
    sleep_on_idle: bool = False

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    log_requests_level: int = 0
    show_time_cost: bool = False
    enable_metrics: bool = False
    bucket_time_to_first_token: Optional[List[float]] = None
    bucket_e2e_request_latency: Optional[List[float]] = None
    bucket_inter_token_latency: Optional[List[float]] = None
    collect_tokens_histogram: bool = False
    decode_log_interval: int = 40
    enable_request_time_stats_logging: bool = False
    kv_events_config: Optional[str] = None

    # API related
    api_key: Optional[str] = None
    file_storage_path: str = "sglang_storage"
    enable_cache_report: bool = False
    reasoning_parser: Optional[str] = None
    tool_call_parser: Optional[str] = None

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"
    preferred_sampling_params: Optional[str] = None

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"

    # Kernel backend
    # Unset backends are chosen in __post_init__ (device-specific overrides
    # for "hpu" and "cpu"; grammar backend defaults to "xgrammar").
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = None
    mm_attention_backend: Optional[str] = None

    # Speculative decoding
    # "NEXTN" is rewritten to "EAGLE" in __post_init__.
    speculative_algorithm: Optional[str] = None
    speculative_draft_model_path: Optional[str] = None
    # When speculative_num_steps is None, all three values below are chosen
    # together in __post_init__ (the other two must then be None as well).
    speculative_num_steps: Optional[int] = None
    speculative_eagle_topk: Optional[int] = None
    speculative_num_draft_tokens: Optional[int] = None
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
    speculative_token_map: Optional[str] = None

    # Expert parallelism
    # Overwritten with tp_size in __post_init__ when EP MoE or DeepEP MoE is on.
    ep_size: int = 1
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
    ep_num_redundant_experts: int = 0
    ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
    init_expert_location: str = "trivial"
    enable_eplb: bool = False
    eplb_algorithm: str = "auto"
    eplb_rebalance_num_iterations: int = 1000
    eplb_rebalance_layers_per_chunk: Optional[int] = None
    expert_distribution_recorder_mode: Optional[
        Literal["stat", "stat_approx", "per_pass", "per_token"]
    ] = None
    expert_distribution_recorder_buffer_size: Optional[int] = None
    enable_expert_distribution_metrics: bool = False
    deepep_config: Optional[str] = None
    # __post_init__ asserts that this is either 1 or None.
    moe_dense_tp_size: Optional[int] = None

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    enable_profile_cuda_graph: bool = False
    enable_nccl_nvls: bool = False
    enable_tokenizer_batch_encode: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    enable_mscclpp: bool = False
    disable_overlap_schedule: bool = False
    disable_overlap_cg_plan: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_dp_lm_head: bool = False
    enable_two_batch_overlap: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    # Overwritten with 16 in __post_init__ when is_hip() is true.
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    enable_custom_logit_processor: bool = False
    enable_hierarchical_cache: bool = False
    hicache_ratio: float = 2.0
    hicache_size: int = 0
    hicache_write_policy: str = "write_through_selective"
    flashinfer_mla_disable_ragged: bool = False
    disable_shared_experts_fusion: bool = False
    disable_chunked_prefix_cache: bool = False
    disable_fast_image_processor: bool = False
    enable_return_hidden_states: bool = False
    warmups: Optional[str] = None

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False
    debug_tensor_dump_prefill_only: bool = False

    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
    disaggregation_mode: str = "null"
    disaggregation_transfer_backend: str = "mooncake"
    disaggregation_bootstrap_port: int = 8998
    disaggregation_ib_device: Optional[str] = None
    num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
    pdlb_url: Optional[str] = None
Byron Hsu's avatar
Byron Hsu committed
233

Lianmin Zheng's avatar
Lianmin Zheng committed
234
    def __post_init__(self):
        """Fill in unset arguments and enforce cross-argument constraints.

        Runs automatically after the dataclass-generated ``__init__``. It
        resolves fields left as ``None`` (tokenizer path, device, served model
        name, random seed, memory fraction, chunked prefill size, kernel
        backends, speculative decoding parameters), overrides settings that
        are incompatible with the selected features, and exports two
        environment variables derived from the final values.

        Raises:
            AssertionError: if an argument combination checked by one of the
                ``assert`` statements below is violated.
        """
        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.warning(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.device is None:
            self.device = get_device()

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # NOTE(review): the thresholds below are written as `N * 1024`, so
        # gpu_mem is presumably in MiB -- confirm against
        # get_device_memory_capacity. It may be None, in which case fixed
        # fallbacks are used.
        gpu_mem = get_device_memory_capacity(self.device)

        # Set mem fraction static
        if self.mem_fraction_static is None:
            if gpu_mem is not None:
                # GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
                # mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity.

                # We want mem_fraction_static to be as large as possible but still has enough room
                # for activations and cuda graph buffers. We use the following heuristic to
                # compute the needed size for activations and cuda graph buffers:
                # - The size of the activation depends on the chunked_prefill_size and model size.
                # - The size of cuda graph buffers depends on the cuda graph capture range and model size.
                # For GPUs with more memory, we use a larger chunked_prefill_size and
                # capture more cuda graphs, so they need to reserve more memory.
                parallel_size = self.tp_size * self.pp_size

                if gpu_mem < 20 * 1024:
                    # T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                    reserved_mem = (2.8 + parallel_size / 10) * 1024
                elif gpu_mem < 35 * 1024:
                    # A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                    reserved_mem = (2.8 + parallel_size / 10) * 1024
                elif gpu_mem < 90 * 1024:
                    # H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160)
                    reserved_mem = (9.5 + parallel_size / 2) * 1024
                elif gpu_mem < 100 * 1024:
                    # H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                    reserved_mem = (12 + parallel_size / 2) * 1024
                elif gpu_mem < 160 * 1024:
                    # H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                    reserved_mem = (12 + parallel_size / 2) * 1024
                else:
                    # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
                    reserved_mem = 32 * 1024

                if self.speculative_algorithm is not None:
                    # draft model and larger cuda graph buffers
                    reserved_mem += 2 * 1024
                if self.enable_dp_attention:
                    reserved_mem += 4 * 1024

                self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3)
            else:
                self.mem_fraction_static = 0.88

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None:
                if gpu_mem < 35 * 1024:  # A10, L40, 4090
                    self.chunked_prefill_size = 2048
                elif gpu_mem < 160 * 1024:  # H100, H200, A100, H20
                    self.chunked_prefill_size = 8192
                else:  # B200, MI300
                    self.chunked_prefill_size = 16384
            else:
                self.chunked_prefill_size = 4096
        # NOTE(review): this check uses the page_size as passed in; page_size
        # may still be overridden further down for flashmla / cutlass_mla.
        assert self.chunked_prefill_size % self.page_size == 0

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on
            # lower-end GPUs with HBM<25G, you can either disable cuda graph or
            # set `cuda_graph_max_bs` to a very small value to reduce the memory
            # overhead of creating cuda graphs, with almost no impact on
            # performance. However, when serving models with TP4 or TP8, we need
            # to enable cuda graph to maintain high performance. In this case,
            # we can set `cuda_graph_max_bs` to 80 (half of the default value
            # 160) to reduce the memory overhead of creating cuda graphs.
            # Looking at the logs from TP4 serving of qwen2-72b, a value of 80
            # is sufficient and can reduce the memory overhead of creating cuda
            # graphs on lower-end GPUs compared to the original 160, avoiding
            # OOM issues.
            # NOTE(review): the text above says HBM<25G, but the condition
            # below applies to gpu_mem < 35 * 1024.
            if gpu_mem is not None and gpu_mem < 35 * 1024:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80

        assert self.moe_dense_tp_size in {
            1,
            None,
        }, "moe_dense_tp_size only support 1 and None currently"

        if self.attention_backend == "flashmla":
            logger.warning(
                "FlashMLA only supports a page_size of 64, change page_size to 64."
            )
            self.page_size = 64

        if self.attention_backend == "cutlass_mla":
            logger.warning(
                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

        # Set kernel backends for hpu device
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        # Set kernel backends
        if self.device == "cpu":
            if self.attention_backend is None:
                self.attention_backend = "intel_amx"
            self.sampling_backend = "pytorch"

        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True

        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
            )

        if self.enable_dp_lm_head:
            # The message names the flag that triggered the check
            # (enable_dp_lm_head), not the one that is missing.
            assert (
                self.enable_dp_attention
            ), "Please enable dp attention when setting enable_dp_lm_head. "

        # DeepEP MoE
        # enable_sp_layernorm is not a declared dataclass field; it is a
        # derived attribute that only exists after __post_init__ has run.
        self.enable_sp_layernorm = False
        if self.enable_deepep_moe:
            if self.deepep_mode == "auto":
                assert (
                    not self.enable_dp_attention
                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
            if self.deepep_mode == "normal":
                logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                self.disable_cuda_graph = True
            self.ep_size = self.tp_size
            self.enable_sp_layernorm = (
                self.dp_size < self.tp_size if self.enable_dp_attention else True
            )
            logger.warning(
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        if self.pp_size > 1:
            self.disable_overlap_schedule = True
            logger.warning(
                "Pipeline parallelism is incompatible with overlap schedule."
            )

        if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
            self.expert_distribution_recorder_mode = "stat"
            logger.info(
                "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
            )

        # NOTE(review): init_expert_location is declared with the default
        # "trivial", so `is not None` holds unless a caller passes None
        # explicitly; as written, an unset ep_dispatch_algorithm therefore
        # becomes "static" in the default configuration as well. Confirm
        # whether a comparison against "trivial" was intended.
        if (self.enable_eplb or (self.init_expert_location is not None)) and (
            self.ep_dispatch_algorithm is None
        ):
            self.ep_dispatch_algorithm = "static"
            logger.info(
                "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
            )

        if self.enable_expert_distribution_metrics and (
            self.expert_distribution_recorder_mode is None
        ):
            self.expert_distribution_recorder_mode = "stat"

        if self.expert_distribution_recorder_buffer_size is None:
            # NOTE(review): eplb_rebalance_num_iterations is declared as an int
            # with default 1000, so the elif branch is only reached when a
            # caller passes None explicitly.
            if (x := self.eplb_rebalance_num_iterations) is not None:
                self.expert_distribution_recorder_buffer_size = x
            elif self.expert_distribution_recorder_mode is not None:
                self.expert_distribution_recorder_buffer_size = 1000

        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
            if self.max_running_requests is None:
                self.max_running_requests = 48
            self.disable_overlap_schedule = True
            logger.warning(
                "Overlap scheduler is disabled because of using "
                "eagle speculative decoding."
            )
            if self.enable_mixed_chunk:
                self.enable_mixed_chunk = False
                logger.warning(
                    "Mixed chunked prefill is disabled because of using "
                    "eagle speculative decoding."
                )

            model_arch = get_model_arch(self)

            # Auto set draft_model_path DeepSeek-V3/R1
            if model_arch == "DeepseekV3ForCausalLM":
                if self.speculative_draft_model_path is None:
                    self.speculative_draft_model_path = self.model_path
                else:
                    logger.warning(
                        "DeepSeek MTP does not require setting speculative_draft_model_path."
                    )

            # Auto choose parameters
            if self.speculative_num_steps is None:
                # The three parameters are chosen as a set, so the other two
                # must not have been supplied on their own.
                assert (
                    self.speculative_eagle_topk is None
                    and self.speculative_num_draft_tokens is None
                )
                (
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
                ) = auto_choose_speculative_params(self)

            if self.page_size > 1 and self.speculative_eagle_topk > 1:
                self.speculative_eagle_topk = 1
                logger.warning(
                    "speculative_eagle_topk is adjusted to 1 when page_size > 1"
                )

            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
            ):
                logger.warning(
                    "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
                )
                self.speculative_num_draft_tokens = self.speculative_num_steps + 1

            # The token generated from the verify step is counted.
            # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

        if is_remote_url(self.model_path):
            self.load_format = "remote"

        # AMD-specific Triton attention KV splits default number
        if is_hip():
            self.triton_attention_num_kv_splits = 16

        # PD disaggregation
        if self.disaggregation_mode == "prefill":
            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for prefill server")
        elif self.disaggregation_mode == "decode":
            self.disable_radix_cache = True
            logger.warning("KV cache is forced as chunk cache for decode server")

        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )
522

Lianmin Zheng's avatar
Lianmin Zheng committed
523
524
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
525
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
526
527
528
529
530
531
532
533
534
535
536
537
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
538
        parser.add_argument(
539
540
541
542
            "--host",
            type=str,
            default=ServerArgs.host,
            help="The host of the HTTP server.",
Yuanhan Zhang's avatar
Yuanhan Zhang committed
543
544
        )
        parser.add_argument(
545
546
547
548
            "--port",
            type=int,
            default=ServerArgs.port,
            help="The port of the HTTP server.",
Yuanhan Zhang's avatar
Yuanhan Zhang committed
549
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
550
551
552
553
554
555
556
557
558
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
559
560
561
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
562
            help="If set, skip init tokenizer and pass input_ids in generate request.",
563
        )
564
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
565
566
567
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
568
569
570
571
572
573
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
574
                "sharded_state",
575
576
                "gguf",
                "bitsandbytes",
577
                "layered",
578
                "remote",
579
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
580
581
582
583
584
585
586
587
588
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
589
            "which is mainly for profiling."
590
591
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
592
593
594
595
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
596
        )
597
598
599
600
601
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
602
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
603
            "--dtype",
Cody Yu's avatar
Cody Yu committed
604
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
605
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
606
607
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
608
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
609
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
610
611
612
613
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
614
615
            '* "float32" for FP32 precision.',
        )
616
617
618
619
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
bjmsong's avatar
bjmsong committed
620
621
622
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
623
624
625
626
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
627
628
629
630
631
632
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
633
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
634
                "bitsandbytes",
635
                "gguf",
636
                "modelopt",
637
                "modelopt_fp4",
638
                "w8a8_int8",
HandH1998's avatar
HandH1998 committed
639
                "w8a8_fp8",
AniZpZ's avatar
AniZpZ committed
640
                "moe_wna16",
HandH1998's avatar
HandH1998 committed
641
                "qoq",
Ying Sheng's avatar
Ying Sheng committed
642
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
643
644
            help="The quantization method.",
        )
645
646
647
648
649
650
651
652
653
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
654
655
656
657
658
659
660
661
662
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
663
            default=ServerArgs.device,
664
            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
665
        )
666
667
668
669
670
671
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
672
673
674
675
676
677
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
678
679
680
681
682
683
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
        )
684
685
686
687
688
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
689
690
691
692
693
694
        parser.add_argument(
            "--enable-multimodal",
            default=ServerArgs.enable_multimodal,
            action="store_true",
            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
        )
695
696
697
698
699
700
701
702
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
703
704
705
706
707
708
709
710
711
712
713
714
        parser.add_argument(
            "--impl",
            type=str,
            default=ServerArgs.impl,
            help="Which implementation of the model to use.\n\n"
            '* "auto" will try to use the SGLang implementation if it exists '
            "and fall back to the Transformers implementation if no SGLang "
            "implementation is available.\n"
            '* "sglang" will use the SGLang model implementation.\n'
            '* "transformers" will use the Transformers model '
            "implementation.\n",
        )
715

716
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
717
718
719
720
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
721
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
722
        )
723
724
725
726
727
728
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
729
730
731
732
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
733
734
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
735
        )
736
737
738
739
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
740
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
741
742
743
744
745
746
747
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
748
        parser.add_argument(
749
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
750
            type=str,
751
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
752
            choices=["lpm", "random", "fcfs", "dfs-weight"],
753
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
754
        )
755
756
757
758
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
759
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
760
        )
761
762
763
764
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
765
            help="How many GBs of RAM to reserve for CPU offloading.",
766
        )
767
768
769
770
771
772
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
773

774
        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
775
        parser.add_argument(
776
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
777
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
778
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
779
            default=ServerArgs.tp_size,
780
            help="The tensor parallelism size.",
781
        )
782
783
784
785
786
787
788
789
790
791
792
793
794
        parser.add_argument(
            "--pipeline-parallel-size",
            "--pp-size",
            type=int,
            default=ServerArgs.pp_size,
            help="The pipeline parallelism size.",
        )
        parser.add_argument(
            "--max-micro-batch-size",
            type=int,
            default=ServerArgs.max_micro_batch_size,
            help="The maximum micro batch size in pipeline parallelism.",
        )
795
796
797
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
798
            default=ServerArgs.stream_interval,
799
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
800
        )
801
802
803
804
805
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
806
807
808
809
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
810
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
811
        )
812
813
814
815
816
817
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
818
819
820
821
822
823
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
824
825
826
827
828
829
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
830
831
832
833
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
834
            help="Model download directory for huggingface.",
835
        )
836
837
838
839
840
841
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
842
843
844
845
846
847
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )
848
849
850
851
852
        parser.add_argument(
            "--sleep-on-idle",
            action="store_true",
            help="Reduce CPU usage when sglang is idle.",
        )
853
854

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
855
856
857
858
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
859
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
860
        )
861
        parser.add_argument(
862
863
864
865
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
866
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
867
        parser.add_argument(
868
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
869
            action="store_true",
870
871
872
873
874
875
876
877
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
            choices=[0, 1, 2],
Lianmin Zheng's avatar
Lianmin Zheng committed
878
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
879
880
881
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
882
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
883
        )
884
885
886
887
888
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
        parser.add_argument(
            "--bucket-time-to-first-token",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_time_to_first_token,
            help="The buckets of time to first token, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-inter-token-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_inter_token_latency,
            help="The buckets of inter-token latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-e2e-request-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_e2e_request_latency,
            help="The buckets of end-to-end request latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--collect-tokens-histogram",
            action="store_true",
            default=ServerArgs.collect_tokens_histogram,
            help="Collect prompt/generation tokens histogram.",
        )
916
917
918
919
920
921
        parser.add_argument(
            "--kv-events-config",
            type=str,
            default=None,
            help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
        )
922
923
924
925
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
926
            help="The log interval of decode batch.",
927
        )
928
929
930
931
932
933
        parser.add_argument(
            "--enable-request-time-stats-logging",
            action="store_true",
            default=ServerArgs.enable_request_time_stats_logging,
            help="Enable per request time stats logging",
        )
934

935
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
936
937
938
939
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
940
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
941
        )
942
        parser.add_argument(
943
            "--file-storage-path",
944
            type=str,
945
            default=ServerArgs.file_storage_path,
946
947
            help="The path of the file storage in backend.",
        )
948
949
950
951
952
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Xihuai Wang's avatar
Xihuai Wang committed
953
954
955
956
957
958
959
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )
960
961
962
963
964
965
966
        parser.add_argument(
            "--tool-call-parser",
            type=str,
            choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
967

968
969
        # Data parallelism
        parser.add_argument(
970
            "--data-parallel-size",
971
972
973
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
974
            help="The data parallelism size.",
975
976
977
978
979
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
980
            help="The load balancing strategy for data parallelism.",
981
982
983
984
985
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
986

987
        # Multi-node distributed serving
988
        parser.add_argument(
989
            "--dist-init-addr",
990
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
991
            type=str,
992
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
993
994
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
995
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
996
        )
997
998
999
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
1000

Lianmin Zheng's avatar
Lianmin Zheng committed
1001
1002
1003
1004
1005
1006
1007
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
1008
1009
1010
1011
1012
        parser.add_argument(
            "--preferred-sampling-params",
            type=str,
            help="json-formatted sampling settings that will be returned in /get_model_info",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1013

1014
1015
1016
1017
1018
1019
1020
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
1021
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
1022
1023
1024
1025
1026
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
1027
1028
1029
1030
1031
1032
1033
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
1034
1035
1036
        )

        # Kernel backend
1037
1038
1039
        parser.add_argument(
            "--attention-backend",
            type=str,
1040
            choices=[
1041
                "aiter",
1042
                "cutlass_mla",
1043
                "fa3",
1044
                "flashinfer",
1045
                "flashmla",
1046
                "intel_amx",
1047
1048
                "torch_native",
                "triton",
1049
            ],
1050
1051
1052
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
1053
1054
1055
1056
1057
1058
1059
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
1060
1061
1062
        parser.add_argument(
            "--grammar-backend",
            type=str,
1063
            choices=["xgrammar", "outlines", "llguidance", "none"],
1064
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
1065
            help="Choose the backend for grammar-guided decoding.",
1066
        )
1067

1068
1069
1070
1071
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
James Liu's avatar
James Liu committed
1072
            choices=["EAGLE", "EAGLE3", "NEXTN"],
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
1089
            help="The number of tokens sampled from the draft model in eagle2 each step.",
1090
1091
            default=ServerArgs.speculative_eagle_topk,
        )
1092
1093
1094
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
1095
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
1096
1097
            default=ServerArgs.speculative_num_draft_tokens,
        )
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
1110
1111
1112
1113
1114
1115
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
        parser.add_argument(
            "--mm-attention-backend",
            type=str,
            choices=["sdpa", "fa3", "triton_attn"],
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )

        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="Enabling DeepEP MoE implementation for EP MoE.",
        )
        parser.add_argument(
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
            default="auto",
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
        parser.add_argument(
            "--ep-num-redundant-experts",
            type=int,
            default=ServerArgs.ep_num_redundant_experts,
            help="Allocate this number of redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--ep-dispatch-algorithm",
            type=str,
            default=ServerArgs.ep_dispatch_algorithm,
            help="The algorithm to choose ranks for redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--init-expert-location",
            type=str,
            default=ServerArgs.init_expert_location,
            help="Initial location of EP experts.",
        )
        parser.add_argument(
            "--enable-eplb",
            action="store_true",
            help="Enable EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-algorithm",
            type=str,
            default=ServerArgs.eplb_algorithm,
            help="Chosen EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-rebalance-num-iterations",
            type=int,
            default=ServerArgs.eplb_rebalance_num_iterations,
            help="Number of iterations to automatically trigger a EPLB re-balance.",
        )
        parser.add_argument(
            "--eplb-rebalance-layers-per-chunk",
            type=int,
            default=ServerArgs.eplb_rebalance_layers_per_chunk,
            help="Number of layers to rebalance per forward pass.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-mode",
            type=str,
            default=ServerArgs.expert_distribution_recorder_mode,
            help="Mode of expert distribution recorder.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-buffer-size",
            type=int,
            default=ServerArgs.expert_distribution_recorder_buffer_size,
            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
        )
        parser.add_argument(
            "--enable-expert-distribution-metrics",
            action="store_true",
            help="Enable logging metrics for expert balancedness",
        )
        parser.add_argument(
            "--deepep-config",
            type=str,
            default=ServerArgs.deepep_config,
            help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
        )
        parser.add_argument(
            "--moe-dense-tp-size",
            type=int,
            default=ServerArgs.moe_dense_tp_size,
            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
        )
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

1257
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
1258
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
1259
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
1260
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
1261
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1262
        )
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
1275
1276
1277
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
1278
            help="Disable cuda graph.",
1279
        )
1280
        parser.add_argument(
1281
1282
            "--disable-cuda-graph-padding",
            action="store_true",
1283
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
1284
        )
1285
1286
1287
1288
1289
        parser.add_argument(
            "--enable-profile-cuda-graph",
            action="store_true",
            help="Enable profiling of cuda graph capture.",
        )
1290
1291
1292
1293
1294
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
1295
1296
1297
1298
1299
        parser.add_argument(
            "--enable-tokenizer-batch-encode",
            action="store_true",
            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
        )
1300
        parser.add_argument(
1301
            "--disable-outlines-disk-cache",
1302
            action="store_true",
1303
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
1304
        )
1305
1306
1307
1308
1309
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
1310
1311
1312
1313
1314
        parser.add_argument(
            "--enable-mscclpp",
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1315
        parser.add_argument(
1316
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
1317
            action="store_true",
1318
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1319
        )
1320
1321
1322
1323
1324
        parser.add_argument(
            "--disable-overlap-cg-plan",
            action="store_true",
            help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
        )
1325
1326
1327
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
1328
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
1329
        )
Ke Bao's avatar
Ke Bao committed
1330
1331
1332
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
1333
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
Ke Bao's avatar
Ke Bao committed
1334
        )
1335
1336
1337
1338
1339
        parser.add_argument(
            "--enable-dp-lm-head",
            action="store_true",
            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
        )
1340
1341
1342
1343
1344
        parser.add_argument(
            "--enable-two-batch-overlap",
            action="store_true",
            help="Enabling two micro batches to overlap.",
        )
1345
1346
1347
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
1348
1349
            help="Optimize the model with torch.compile. Experimental feature.",
        )
1350
        parser.add_argument(
1351
            "--torch-compile-max-bs",
1352
            type=int,
1353
            default=ServerArgs.torch_compile_max_bs,
1354
1355
            help="Set the maximum batch size when using torch compile.",
        )
1356
1357
1358
1359
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
1360
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
1361
        )
1362
1363
1364
1365
1366
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1367
        parser.add_argument(
1368
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
1369
            action="store_true",
1370
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1371
        )
1372
        parser.add_argument(
1373
            "--triton-attention-reduce-in-fp32",
1374
            action="store_true",
1375
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
1376
            "This only affects Triton attention kernels.",
1377
        )
1378
1379
1380
1381
1382
1383
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
1384
1385
1386
1387
1388
1389
1390
1391
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
1392
1393
1394
1395
1396
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
1397
1398
1399
1400
1401
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
1402
1403
1404
1405
1406
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
1407
1408
1409
1410
1411
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
1412
1413
1414
1415
1416
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
1417
1418
1419
1420
1421
1422
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
Zhiqiang Xie's avatar
Zhiqiang Xie committed
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
        parser.add_argument(
            "--hicache-size",
            type=int,
            default=ServerArgs.hicache_size,
            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
        )
        parser.add_argument(
            "--hicache-write-policy",
            type=str,
            choices=["write_back", "write_through", "write_through_selective"],
            default=ServerArgs.hicache_write_policy,
            help="The write policy of hierarchical cache.",
        )
1436
        parser.add_argument(
1437
            "--flashinfer-mla-disable-ragged",
1438
            action="store_true",
1439
            help="Not using ragged prefill wrapper when running flashinfer mla",
1440
        )
1441
        parser.add_argument(
1442
1443
1444
            "--disable-shared-experts-fusion",
            action="store_true",
            help="Disable shared experts fusion optimization for deepseek v3/r1.",
1445
        )
1446
1447
1448
1449
1450
        parser.add_argument(
            "--disable-chunked-prefix-cache",
            action="store_true",
            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1451
1452
1453
1454
1455
        parser.add_argument(
            "--disable-fast-image-processor",
            action="store_true",
            help="Adopt base image processor instead of fast image processor.",
        )
1456
1457
1458
1459
1460
        parser.add_argument(
            "--enable-return-hidden-states",
            action="store_true",
            help="Enable returning hidden states with responses.",
        )
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )
1488
1489
1490
1491
1492
        parser.add_argument(
            "--debug-tensor-dump-prefill-only",
            action="store_true",
            help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
        )
1493

Byron Hsu's avatar
Byron Hsu committed
1494
1495
1496
1497
1498
1499
1500
1501
        # Disaggregation
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
1502
1503
1504
1505
        parser.add_argument(
            "--disaggregation-transfer-backend",
            type=str,
            default=ServerArgs.disaggregation_transfer_backend,
1506
            choices=["mooncake", "nixl"],
1507
1508
            help="The backend for disaggregation transfer. Default is mooncake.",
        )
1509
1510
1511
1512
1513
1514
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )
1515
1516
1517
1518
        parser.add_argument(
            "--disaggregation-ib-device",
            type=str,
            default=ServerArgs.disaggregation_ib_device,
1519
1520
1521
            help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
            "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
            "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
1522
        )
1523
1524
1525
1526
1527
1528
        parser.add_argument(
            "--num-reserved-decode-tokens",
            type=int,
            default=ServerArgs.num_reserved_decode_tokens,
            help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
        )
1529
1530
1531
1532
1533
1534
        parser.add_argument(
            "--pdlb-url",
            type=str,
            default=None,
            help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
        )
Byron Hsu's avatar
Byron Hsu committed
1535

Lianmin Zheng's avatar
Lianmin Zheng committed
1536
1537
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        The long-form parallelism options are first copied onto the short
        attribute names (``tp_size``, ``pp_size``, ``dp_size``, ``ep_size``)
        on ``args`` itself, so the namespace is mutated. Every dataclass field
        of ``cls`` is then read from ``args`` by name.

        Args:
            args: Namespace that carries one attribute per dataclass field,
                plus the four ``*_parallel_size`` attributes.

        Returns:
            A new ``cls`` populated from ``args``.
        """
        # Mirror the long CLI names onto the short dataclass field names.
        args.tp_size = args.tensor_parallel_size
        args.pp_size = args.pipeline_parallel_size
        args.dp_size = args.data_parallel_size
        args.ep_size = args.expert_parallel_size

        field_values = {}
        for field in dataclasses.fields(cls):
            field_values[field.name] = getattr(args, field.name)
        return cls(**field_values)

    def url(self):
        """Return the ``http://host:port`` URL of this server.

        The host is wrapped in square brackets when ``is_valid_ipv6_address``
        reports it as an IPv6 address.
        """
        host = f"[{self.host}]" if is_valid_ipv6_address(self.host) else self.host
        return f"http://{host}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
1550

1551
1552
    def check_server_args(self):
        """Validate cross-argument constraints and normalize ``lora_paths``.

        Each constraint is checked with an ``assert`` whose message names the
        violated requirement.

        Side effects:
            When ``lora_paths`` is a list it is replaced by a dict mapping
            adapter name to path. A ``name=path`` entry is split on the first
            ``=``; an entry without ``=`` uses the whole string as both name
            and path.

        Raises:
            AssertionError: If any of the checked constraints is violated.
        """
        assert (
            self.tp_size * self.pp_size
        ) % self.nnodes == 0, "tp_size * pp_size must be divisible by number of nodes"

        # FIXME pp constraints
        if self.pp_size > 1:
            assert (
                self.disable_overlap_schedule
                and self.speculative_algorithm is None
                and not self.enable_mixed_chunk
            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"

        # Checked separately from the LoRA/radix-cache constraint so that each
        # failure reports the requirement that was actually violated.
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
        # FIXME
        assert (
            self.lora_paths is None or self.disable_radix_cache
        ), "compatibility of lora and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

        if isinstance(self.lora_paths, list):
            named_paths = {}
            for entry in self.lora_paths:
                # "name=path" splits on the first "="; a bare path names itself.
                name, sep, path = entry.partition("=")
                named_paths[name] = path if sep else entry
            self.lora_paths = named_paths

Lianmin Zheng's avatar
Lianmin Zheng committed
1585

Lianmin Zheng's avatar
Lianmin Zheng committed
1586
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Parse command line arguments into a ``ServerArgs`` instance.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    cli_parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli_parser)
    return ServerArgs.from_cli_args(cli_parser.parse_args(argv))


1604
1605
1606
# Offset added to the server port to derive the default dist-init port that
# PortArgs.init_new uses for single-node DP attention when --dist-init-addr is unset.
ZMQ_TCP_PORT_DELTA = 233


Lianmin Zheng's avatar
Lianmin Zheng committed
1607
1608
@dataclasses.dataclass
class PortArgs:
    """ZMQ endpoint names and the NCCL port used by the server processes."""

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate a fresh set of endpoints for ``server_args``.

        Without DP attention, every ZMQ endpoint is an ``ipc://`` path backed
        by a newly created temporary file. With DP attention, endpoints are
        ``tcp://`` addresses on the dist-init host, at consecutive ports
        after the dist-init port.

        Args:
            server_args: Server arguments; reads ``port``,
                ``enable_dp_attention``, ``nnodes`` and ``dist_init_addr``.
            dp_rank: Data-parallel rank. When given (DP attention only), the
                scheduler input port is offset by this rank.

        Returns:
            A populated ``PortArgs``.

        Raises:
            ValueError: If DP attention is enabled on more than one node and
                ``dist_init_addr`` is not set.
            AssertionError: If ``dist_init_addr`` does not split into exactly
                a host and a port.
        """
        # Start from a random offset above the server port and probe until a
        # free port is found; walk upward below 60000, downward otherwise.
        port = server_args.port + random.randint(100, 1000)
        while not is_port_available(port):
            if port < 60000:
                port += 42
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node

            def new_ipc_name() -> str:
                # Only the file's path is needed; delete=False leaves the file
                # on disk, and the handle is closed as soon as the name is read.
                with tempfile.NamedTemporaryFile(delete=False) as tmp:
                    return f"ipc://{tmp.name}"

            return PortArgs(
                tokenizer_ipc_name=new_ipc_name(),
                scheduler_input_ipc_name=new_ipc_name(),
                detokenizer_ipc_name=new_ipc_name(),
                nccl_port=port,
                rpc_ipc_name=new_ipc_name(),
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        raw_addr = server_args.dist_init_addr
        if raw_addr is None:
            if server_args.nnodes != 1:
                # Without this check the code would fail on None.startswith().
                raise ValueError(
                    "please provide --dist-init-addr as host:port of head node"
                )
            dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
        elif raw_addr.startswith("["):  # ipv6 address
            port_num, host = configure_ipv6(raw_addr)
            dist_init_addr = (host, str(port_num))
        else:
            dist_init_addr = raw_addr.split(":")

        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"

        dist_init_host, dist_init_port = dist_init_addr
        # port_base, +1 and +2 go to tokenizer, detokenizer and rpc; the
        # scheduler input ports start at +3.
        port_base = int(dist_init_port) + 1
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 3
        else:
            scheduler_input_port = port_base + 3 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            nccl_port=port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
        )
1672

1673
1674
1675
1676
1677
1678
1679
1680
1681
1682

class LoRAPathAction(argparse.Action):
    """Argparse action that collects LoRA paths into a ``{name: path}`` dict.

    A ``name=path`` value is split on the first ``=``; a value without ``=``
    uses the whole string as both name and path.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        named_paths = {}
        setattr(namespace, self.dest, named_paths)
        for entry in values:
            name, sep, path = entry.partition("=")
            named_paths[name] = path if sep else entry
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692


class DeprecatedAction(argparse.Action):
    """Argparse action for removed options: using the option raises an error.

    The option's ``help`` text is used as the error message.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        # nargs defaults to 0 so the option takes no value on the command line.
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)
1693
1694


1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
def get_model_arch(args: ServerArgs):
    """Return the first entry of ``architectures`` in the model's config.

    The config is loaded with ``get_config`` from ``args.model_path``, with
    ``args.json_model_override_args`` (a JSON string) decoded and passed as
    the model override arguments.
    """
    override_args = json.loads(args.json_model_override_args)
    config = get_config(
        args.model_path,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
        model_override_args=override_args,
    )
    return config.architectures[0]


1705
def auto_choose_speculative_params(self: ServerArgs):
    """
    Automatically choose the parameters for speculative decoding.

    The choice depends only on the first architecture name in the model's
    config. Returns a 3-tuple of ints (presumably number of steps, eagle
    top-k and number of draft tokens; confirm against the caller).

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
    """
    arch = get_model_arch(self)

    if arch in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
        # The default value for deepseek
        return (3, 1, 4)

    # The default value for all other models; llama ("LlamaForCausalLM") and
    # grok ("Grok1ForCausalLM", "Grok1VForCausalLM") use this same default.
    return (5, 4, 8)