# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The arguments of the server."""

import argparse
import dataclasses
18
import json
19
import logging
20
import os
21
import random
22
import sys
23
import tempfile
24
from typing import List, Literal, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
25

26
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
27
from sglang.srt.layers.utils import is_sm100_supported
28
from sglang.srt.lora.lora_registry import LoRARef
Xihuai Wang's avatar
Xihuai Wang committed
29
from sglang.srt.reasoning_parser import ReasoningParser
30
from sglang.srt.utils import (
31
32
    LORA_TARGET_ALL_MODULES,
    SUPPORTED_LORA_TARGET_MODULES,
Vincent's avatar
Vincent committed
33
    configure_ipv6,
34
    get_device,
Lianmin Zheng's avatar
Lianmin Zheng committed
35
    get_device_memory_capacity,
36
    is_flashinfer_available,
HAI's avatar
HAI committed
37
    is_hip,
38
    is_port_available,
39
    is_remote_url,
40
    is_triton_kernels_available,
41
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
42
    nullable_str,
43
)
44

45
46
# Module-scoped logger. Naming it after this module's dotted import path
# places it in the standard ``logging`` hierarchy, so its level and handlers
# can be configured per module.
logger = logging.getLogger(name=__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
47
48
49

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
50
    # Model and tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
51
52
53
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
54
    skip_tokenizer_init: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
55
    load_format: str = "auto"
56
    model_loader_extra_config: str = "{}"
57
    trust_remote_code: bool = False
58
    context_length: Optional[int] = None
59
    is_embedding: bool = False
60
    enable_multimodal: Optional[bool] = None
61
    revision: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
62
    model_impl: str = "auto"
Lianmin Zheng's avatar
Lianmin Zheng committed
63

Lianmin Zheng's avatar
Lianmin Zheng committed
64
    # HTTP server
Lianmin Zheng's avatar
Lianmin Zheng committed
65
66
    host: str = "127.0.0.1"
    port: int = 30000
Lianmin Zheng's avatar
Lianmin Zheng committed
67
68
    skip_server_warmup: bool = False
    warmups: Optional[str] = None
69
    nccl_port: Optional[int] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
70

Lianmin Zheng's avatar
Lianmin Zheng committed
71
72
73
74
75
76
    # Quantization and data type
    dtype: str = "auto"
    quantization: Optional[str] = None
    quantization_param_path: Optional[str] = None
    kv_cache_dtype: str = "auto"

Lianmin Zheng's avatar
Lianmin Zheng committed
77
    # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
78
    mem_fraction_static: Optional[float] = None
79
    max_running_requests: Optional[int] = None
80
    max_queued_requests: Optional[int] = sys.maxsize
81
    max_total_tokens: Optional[int] = None
82
    chunked_prefill_size: Optional[int] = None
83
    max_prefill_tokens: int = 16384
84
    schedule_policy: str = "fcfs"
85
    schedule_conservativeness: float = 1.0
86
    cpu_offload_gb: int = 0
87
    page_size: Optional[int] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
88
89
90
    hybrid_kvcache_ratio: Optional[float] = None
    swa_full_tokens_ratio: float = 0.8
    disable_hybrid_swa_memory: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
91

Lianmin Zheng's avatar
Lianmin Zheng committed
92
93
    # Runtime options
    device: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
94
    tp_size: int = 1
95
96
    pp_size: int = 1
    max_micro_batch_size: Optional[int] = None
97
    stream_interval: int = 1
98
    stream_output: bool = False
99
    random_seed: Optional[int] = None
100
    constrained_json_whitespace_pattern: Optional[str] = None
101
    watchdog_timeout: float = 300
102
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
103
    download_dir: Optional[str] = None
104
    base_gpu_id: int = 0
105
    gpu_id_step: int = 1
106
    sleep_on_idle: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
107
108
109

    # Logging
    log_level: str = "info"
110
    log_level_http: Optional[str] = None
111
    log_requests: bool = False
112
    log_requests_level: int = 0
113
    crash_dump_folder: Optional[str] = None
Liangsheng Yin's avatar
Liangsheng Yin committed
114
    show_time_cost: bool = False
115
    enable_metrics: bool = False
116
    enable_metrics_for_all_schedulers: bool = False
117
118
    bucket_time_to_first_token: Optional[List[float]] = None
    bucket_inter_token_latency: Optional[List[float]] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
119
    bucket_e2e_request_latency: Optional[List[float]] = None
120
    collect_tokens_histogram: bool = False
121
    decode_log_interval: int = 40
122
    enable_request_time_stats_logging: bool = False
123
    kv_events_config: Optional[str] = None
Liangsheng Yin's avatar
Liangsheng Yin committed
124

125
    # API related
126
    api_key: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
127
128
129
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    completion_template: Optional[str] = None
130
    file_storage_path: str = "sglang_storage"
131
    enable_cache_report: bool = False
Xihuai Wang's avatar
Xihuai Wang committed
132
    reasoning_parser: Optional[str] = None
133
    tool_call_parser: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
134

135
136
137
    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"
138

139
    # Multi-node distributed serving
140
    dist_init_addr: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
141
    nnodes: int = 1
142
    node_rank: int = 0
Lianmin Zheng's avatar
Lianmin Zheng committed
143
144
145

    # Model override args in JSON
    json_model_override_args: str = "{}"
146
    preferred_sampling_params: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
147

148
    # LoRA
149
    enable_lora: Optional[bool] = None
150
    max_lora_rank: Optional[int] = None
151
    lora_target_modules: Optional[Union[set[str], List[str]]] = None
152
    lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
153
    max_loaded_loras: Optional[int] = None
154
    max_loras_per_batch: int = 8
155
    lora_backend: str = "triton"
156
157

    # Kernel backend
158
    attention_backend: Optional[str] = None
159
160
    decode_attention_backend: Optional[str] = None
    prefill_attention_backend: Optional[str] = None
161
    sampling_backend: Optional[str] = None
162
    grammar_backend: Optional[str] = None
163
    mm_attention_backend: Optional[str] = None
164

165
166
    # Speculative decoding
    speculative_algorithm: Optional[str] = None
167
    speculative_draft_model_path: Optional[str] = None
168
169
170
    speculative_num_steps: Optional[int] = None
    speculative_eagle_topk: Optional[int] = None
    speculative_num_draft_tokens: Optional[int] = None
171
172
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
173
    speculative_token_map: Optional[str] = None
174

175
176
    # Expert parallelism
    ep_size: int = 1
177
    moe_a2a_backend: Optional[Literal["deepep"]] = None
178
179
    enable_flashinfer_cutlass_moe: bool = False
    enable_flashinfer_trtllm_moe: bool = False
180
    enable_flashinfer_allreduce_fusion: bool = False
181
    deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
    ep_num_redundant_experts: int = 0
    ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
    init_expert_location: str = "trivial"
    enable_eplb: bool = False
    eplb_algorithm: str = "auto"
    eplb_rebalance_num_iterations: int = 1000
    eplb_rebalance_layers_per_chunk: Optional[int] = None
    expert_distribution_recorder_mode: Optional[
        Literal["stat", "stat_approx", "per_pass", "per_token"]
    ] = None
    expert_distribution_recorder_buffer_size: Optional[int] = None
    enable_expert_distribution_metrics: bool = False
    deepep_config: Optional[str] = None
    moe_dense_tp_size: Optional[int] = None

Lianmin Zheng's avatar
Lianmin Zheng committed
197
198
199
200
201
    # Hierarchical cache
    enable_hierarchical_cache: bool = False
    hicache_ratio: float = 2.0
    hicache_size: int = 0
    hicache_write_policy: str = "write_through_selective"
202
203
    hicache_io_backend: str = "kernel"
    hicache_mem_layout: str = "layer_first"
Lianmin Zheng's avatar
Lianmin Zheng committed
204
    hicache_storage_backend: Optional[str] = None
pansicheng's avatar
pansicheng committed
205
    hicache_storage_prefetch_policy: str = "best_effort"
Lianmin Zheng's avatar
Lianmin Zheng committed
206

207
208
    # Double Sparsity
    enable_double_sparsity: bool = False
Vincent's avatar
Vincent committed
209
    ds_channel_config_path: Optional[str] = None
210
211
212
213
214
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

215
    # Optimization/debug options
Lianmin Zheng's avatar
Lianmin Zheng committed
216
    disable_radix_cache: bool = False
217
218
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
219
    disable_cuda_graph: bool = False
220
    disable_cuda_graph_padding: bool = False
221
    enable_profile_cuda_graph: bool = False
222
    enable_cudagraph_gc: bool = False
223
    enable_nccl_nvls: bool = False
224
    enable_symm_mem: bool = False
225
    enable_tokenizer_batch_encode: bool = False
226
    disable_outlines_disk_cache: bool = False
227
    disable_custom_all_reduce: bool = False
228
    enable_mscclpp: bool = False
229
    disable_overlap_schedule: bool = False
230
    enable_mixed_chunk: bool = False
Ke Bao's avatar
Ke Bao committed
231
    enable_dp_attention: bool = False
232
    enable_dp_lm_head: bool = False
233
    enable_two_batch_overlap: bool = False
234
    tbo_token_distribution_threshold: float = 0.48
235
    enable_torch_compile: bool = False
236
    torch_compile_max_bs: int = 32
237
    torchao_config: str = ""
238
    enable_nan_detection: bool = False
239
    enable_p2p_check: bool = False
240
    triton_attention_reduce_in_fp32: bool = False
241
    triton_attention_num_kv_splits: int = 8
242
    num_continuous_decode_steps: int = 1
243
    delete_ckpt_after_loading: bool = False
244
    enable_memory_saver: bool = False
245
    allow_auto_truncate: bool = False
246
    enable_custom_logit_processor: bool = False
247
    flashinfer_mla_disable_ragged: bool = False
248
    disable_shared_experts_fusion: bool = False
249
    disable_chunked_prefix_cache: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
250
    disable_fast_image_processor: bool = False
251
    enable_return_hidden_states: bool = False
Yuan Luo's avatar
Yuan Luo committed
252
    enable_triton_kernel_moe: bool = False
253
    enable_flashinfer_mxfp4_moe: bool = False
254
    scheduler_recv_interval: int = 1
255
256
257
258
259

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False
260
    debug_tensor_dump_prefill_only: bool = False
261

Lianmin Zheng's avatar
Lianmin Zheng committed
262
    # PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
Byron Hsu's avatar
Byron Hsu committed
263
    disaggregation_mode: str = "null"
264
    disaggregation_transfer_backend: str = "mooncake"
265
    disaggregation_bootstrap_port: int = 8998
Byron Hsu's avatar
Byron Hsu committed
266
267
268
    disaggregation_decode_tp: Optional[int] = None
    disaggregation_decode_dp: Optional[int] = None
    disaggregation_prefill_pp: Optional[int] = 1
269
    disaggregation_ib_device: Optional[str] = None
270
    num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
271
    pdlb_url: Optional[str] = None
Byron Hsu's avatar
Byron Hsu committed
272

273
274
    # For model weight update
    custom_weight_loader: Optional[List[str]] = None
275
    weight_loader_disable_mmap: bool = False
276

277
278
279
280
    # For PD-Multiplexing
    enable_pdmux: bool = False
    sm_group_num: int = 3

281
282
283
    # For tool server
    tool_server: Optional[str] = None

284
285
286
287
    # Deprecated arguments
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False

Lianmin Zheng's avatar
Lianmin Zheng committed
288
    def __post_init__(self):
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304

        # Check deprecated arguments
        def print_deprecated_warning(message: str):
            logger.warning(f"\033[33m{message}\033[0m")

        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            print_deprecated_warning(
                "NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
            )
        if self.enable_deepep_moe:
            self.moe_a2a_backend = "deepep"
            print_deprecated_warning(
                "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
            )

305
        # Set missing default values
Lianmin Zheng's avatar
Lianmin Zheng committed
306
307
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
308
309
        if self.served_model_name is None:
            self.served_model_name = self.model_path
310
311
        if self.device is None:
            self.device = get_device()
312
313
314
        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

Lianmin Zheng's avatar
Lianmin Zheng committed
315
        gpu_mem = get_device_memory_capacity(self.device)
316

317
        # Set mem fraction static
Lianmin Zheng's avatar
Lianmin Zheng committed
318
        if self.mem_fraction_static is None:
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
            if gpu_mem is not None:
                # GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
                # mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity.

                # We want mem_fraction_static to be as large as possible but still has enough room
                # for activations and cuda graph buffers. We use the following heuristic to
                # compute the needed size for activations and cuda graph buffers:
                # - The size of the activation depends on the chunked_prefill_size and model size.
                # - The size of cuda graph buffers depends on the cuda graph capture range and model size.
                # For GPUs with more memory, we use a larger chunked_prefill_size and
                # capture more cuda graphs, so they need to reserve more memory.
                parallel_size = self.tp_size * self.pp_size

                if gpu_mem < 20 * 1024:
                    # T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                    reserved_mem = (2.8 + parallel_size / 10) * 1024
                elif gpu_mem < 35 * 1024:
                    # A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                    reserved_mem = (2.8 + parallel_size / 10) * 1024
                elif gpu_mem < 90 * 1024:
                    # H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160)
                    reserved_mem = (9.5 + parallel_size / 2) * 1024
                elif gpu_mem < 100 * 1024:
                    # H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                    reserved_mem = (12 + parallel_size / 2) * 1024
                elif gpu_mem < 160 * 1024:
                    # H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                    reserved_mem = (12 + parallel_size / 2) * 1024
347
                else:
348
349
350
                    # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
                    reserved_mem = 32 * 1024

351
                if self.speculative_algorithm is not None:
352
353
354
355
356
357
                    # draft model and larger cuda graph buffers
                    reserved_mem += 2 * 1024
                if self.enable_dp_attention:
                    reserved_mem += 4 * 1024

                self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3)
358
            else:
359
                self.mem_fraction_static = 0.88
360

361
            # Lazy init to avoid circular import
Lianmin Zheng's avatar
Lianmin Zheng committed
362
            # Multimodal models need more memory for the image processor
363
364
365
            from sglang.srt.configs.model_config import ModelConfig

            model_config = ModelConfig.from_server_args(self)
Lianmin Zheng's avatar
Lianmin Zheng committed
366
367
            if model_config.is_multimodal:
                self.adjust_mem_fraction_for_vlm(model_config)
368

369
370
        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
371
372
373
374
375
376
377
            if gpu_mem is not None:
                if gpu_mem < 35 * 1024:  # A10, L40, 4090
                    self.chunked_prefill_size = 2048
                elif gpu_mem < 160 * 1024:  # H100, H200, A100, H20
                    self.chunked_prefill_size = 8192
                else:  # B200, MI300
                    self.chunked_prefill_size = 16384
378
            else:
379
                self.chunked_prefill_size = 4096
Lianmin Zheng's avatar
Lianmin Zheng committed
380

381
382
383
384
385
386
387
388
389
        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
            if gpu_mem is not None and gpu_mem < 35 * 1024:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80

390
        # Set kernel backends for hpu device
391
392
393
394
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

Lianmin Zheng's avatar
Lianmin Zheng committed
395
        # Set kernel backends
396
397
398
399
400
        if self.device == "cpu":
            if self.attention_backend is None:
                self.attention_backend = "intel_amx"
            self.sampling_backend = "pytorch"

401
        if self.sampling_backend is None:
402
403
404
405
406
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
407
            logger.warning(
408
409
410
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True
411

412
413
414
415
416
417
        if self.attention_backend == "ascend":
            logger.warning(
                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

418
419
420
421
        if (
            self.attention_backend == "flashmla"
            or self.decode_attention_backend == "flashmla"
        ):
Lianmin Zheng's avatar
Lianmin Zheng committed
422
423
424
425
426
            logger.warning(
                "FlashMLA only supports a page_size of 64, change page_size to 64."
            )
            self.page_size = 64

427
428
429
430
        if (
            self.attention_backend == "cutlass_mla"
            or self.decode_attention_backend == "cutlass_mla"
        ):
Lianmin Zheng's avatar
Lianmin Zheng committed
431
432
433
434
435
            logger.warning(
                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
        if self.attention_backend == "trtllm_mla":
            if not is_sm100_supported():
                raise ValueError(
                    "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
                )

            if self.page_size not in [32, 64]:
                logger.warning(
                    f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
                )
                self.page_size = 64
            if self.speculative_algorithm is not None:
                raise ValueError(
                    "trtllm_mla backend does not support speculative decoding yet."
                )

452
453
454
455
456
        if (
            self.attention_backend == "trtllm_mha"
            or self.decode_attention_backend == "trtllm_mha"
            or self.prefill_attention_backend == "trtllm_mha"
        ):
457
458
459
460
461
462
463
464
465
466
467
468
469
            if not is_sm100_supported():
                raise ValueError(
                    "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
                )

            if self.page_size not in [16, 32, 64]:
                logger.warning(
                    f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
                )
                self.page_size = 64

            if self.speculative_algorithm is not None:
                raise ValueError(
470
                    "trtllm_mha backend does not support speculative decoding yet."
471
                )
472

473
474
        model_arch = self.get_hf_config().architectures[0]
        if model_arch in ["GptOssForCausalLM"]:
475
476
477
478
479
480
481
            if self.attention_backend is None:
                # default is triton, but we could have trtllm_mha as an option
                self.attention_backend = "triton"
            assert (
                self.attention_backend == "trtllm_mha"
                or self.attention_backend == "triton"
            )
482
483
484
485
486
487
488
            quantization_config = getattr(
                self.get_hf_config(), "quantization_config", None
            )
            is_mxfp4_quant_format = (
                quantization_config is not None
                and quantization_config.get("quant_method") == "mxfp4"
            )
Xiaoyu Zhang's avatar
Xiaoyu Zhang committed
489

490
            if is_sm100_supported() and is_mxfp4_quant_format:
491
492
                self.enable_flashinfer_mxfp4_moe = True
                self.enable_triton_kernel_moe = False
493
494
495
                logger.info(
                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
                )
496
            else:
497
498
499
500
501
502
503
504
505
                if self.enable_triton_kernel_moe:
                    assert (
                        self.ep_size == 1
                    ), "Triton kernel MoE is only supported when ep_size == 1"
                if not self.enable_triton_kernel_moe and self.ep_size == 1:
                    self.enable_triton_kernel_moe = True
                    logger.info(
                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
                    )
Xiaoyu Zhang's avatar
Xiaoyu Zhang committed
506

507
            self.disable_hybrid_swa_memory = True
508

509
            if is_mxfp4_quant_format:
Ying Sheng's avatar
Ying Sheng committed
510
511
512
                # use bf16 for mxfp4 triton kernels
                self.dtype = "bfloat16"

513
514
515
516
517
518
519
520
521
522
523
524
525
526
        if self.attention_backend == "dual_chunk_flash_attn":
            logger.warning(
                "Mixed chunk is disabled because of using dual chunk flash attention backend"
            )
            logger.warning(
                "Radix cache is disabled because of using dual chunk flash attention backend"
            )
            logger.warning(
                "Cuda graph is disabled because of using dual chunk flash attention backend"
            )
            self.enable_mixed_chunk = False
            self.disable_cuda_graph = True
            self.disable_radix_cache = True

527
528
529
530
531
532
533
534
        # Set page size
        if self.page_size is None:
            self.page_size = 1

        # AMD-specific Triton attention KV splits default number
        if is_hip():
            self.triton_attention_num_kv_splits = 16

535
536
537
        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"
538

539
        # Data parallelism attention
Ke Bao's avatar
Ke Bao committed
540
        if self.enable_dp_attention:
541
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
Lianmin Zheng's avatar
Lianmin Zheng committed
542
543
544
545
546
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
547
            logger.warning(
548
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
549
            )
550

551
552
553
        if self.enable_dp_lm_head:
            assert (
                self.enable_dp_attention
554
            ), "Please enable dp attention when setting enable_dp_lm_head. "
555

556
        # MoE kernel
557
        if self.enable_flashinfer_cutlass_moe:
558
559
560
561
            assert (
                self.quantization == "modelopt_fp4"
            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
            os.environ["TRTLLM_ENABLE_PDL"] = "1"
562
563
564
565
            assert self.ep_size in [
                1,
                self.tp_size,
            ], "The expert parallel size must be 1 or the same as the tensor parallel size"
566

567
568
569
570
571
572
573
        if self.enable_flashinfer_trtllm_moe:
            if not self.disable_shared_experts_fusion:
                self.disable_shared_experts_fusion = True
                logger.warning(
                    "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
                )

574
        # DeepEP MoE
575
        if self.moe_a2a_backend == "deepep":
576
577
578
            if self.deepep_mode == "normal":
                logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                self.disable_cuda_graph = True
579
            self.ep_size = self.tp_size
Lianmin Zheng's avatar
Lianmin Zheng committed
580
            logger.warning(
581
582
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )
583

584
585
586
        if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
            self.expert_distribution_recorder_mode = "stat"
            logger.info(
587
                "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
588
589
590
591
592
593
594
            )

        if (self.enable_eplb or (self.init_expert_location is not None)) and (
            self.ep_dispatch_algorithm is None
        ):
            self.ep_dispatch_algorithm = "static"
            logger.info(
595
                "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
596
597
            )

598
        if self.enable_eplb:
599
            assert self.ep_size > 1 or self.moe_a2a_backend is not None
600

601
602
603
604
605
        if self.enable_expert_distribution_metrics and (
            self.expert_distribution_recorder_mode is None
        ):
            self.expert_distribution_recorder_mode = "stat"

606
        if self.expert_distribution_recorder_buffer_size is None:
607
608
            if (x := self.eplb_rebalance_num_iterations) is not None:
                self.expert_distribution_recorder_buffer_size = x
609
610
611
            elif self.expert_distribution_recorder_mode is not None:
                self.expert_distribution_recorder_buffer_size = 1000

Lianmin Zheng's avatar
Lianmin Zheng committed
612
613
614
615
616
617
618
        # Pipeline parallelism
        if self.pp_size > 1:
            self.disable_overlap_schedule = True
            logger.warning(
                "Pipeline parallelism is incompatible with overlap schedule."
            )

619
        # Speculative Decoding
620
621
622
623
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

Lianmin Zheng's avatar
Lianmin Zheng committed
624
        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
625
            if self.max_running_requests is None:
626
                self.max_running_requests = 48
627
            self.disable_overlap_schedule = True
Lianmin Zheng's avatar
Lianmin Zheng committed
628
            logger.warning(
629
                "Overlap scheduler is disabled because of using "
630
                "eagle speculative decoding."
631
            )
632
633
634
635
636
637
            if self.enable_mixed_chunk:
                self.enable_mixed_chunk = False
                logger.warning(
                    "Mixed chunked prefill is disabled because of using "
                    "eagle speculative decoding."
                )
638

Lianmin Zheng's avatar
Lianmin Zheng committed
639
            model_arch = self.get_hf_config().architectures[0]
Yuxuan Zhang's avatar
Yuxuan Zhang committed
640
            if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
Hanming Lu's avatar
Hanming Lu committed
641
                # Auto set draft_model_path DeepSeek-V3/R1
642
643
644
645
646
647
                if self.speculative_draft_model_path is None:
                    self.speculative_draft_model_path = self.model_path
                else:
                    logger.warning(
                        "DeepSeek MTP does not require setting speculative_draft_model_path."
                    )
648

649
650
651
652
653
654
655
656
657
658
            # Auto choose parameters
            if self.speculative_num_steps is None:
                assert (
                    self.speculative_eagle_topk is None
                    and self.speculative_num_draft_tokens is None
                )
                (
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
659
                ) = auto_choose_speculative_params(self)
660

661
662
663
664
            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
            ):
Lianmin Zheng's avatar
Lianmin Zheng committed
665
                logger.warning(
666
667
668
                    "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
                )
                self.speculative_num_draft_tokens = self.speculative_num_steps + 1
669

670
            # The token generated from the verify step is counted.
671
            # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
672
            # assert self.speculative_num_steps < self.speculative_num_draft_tokens
673

674
675
676
677
678
679
        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

680
        # Model loading
681
682
        if is_remote_url(self.model_path):
            self.load_format = "remote"
683
684
        if self.custom_weight_loader is None:
            self.custom_weight_loader = []
685

Byron Hsu's avatar
Byron Hsu committed
686
        # PD disaggregation
Byron Hsu's avatar
Byron Hsu committed
687
688
689
690
691
692
693
694
        if self.disaggregation_mode == "decode":
            assert (
                self.disaggregation_decode_tp is None
            ), "Cannot set --disaggregation-decode-tp for the decode engine."
            assert (
                self.disaggregation_decode_dp is None
            ), "Cannot set --disaggregation-decode-dp for the decode engine."

Byron Hsu's avatar
Byron Hsu committed
695
            self.disable_radix_cache = True
696
            logger.warning("KV cache is forced as chunk cache for decode server")
Byron Hsu's avatar
Byron Hsu committed
697
698
699
700
701
702
703
704
705
706
707
        elif self.disaggregation_mode == "prefill":
            if self.disaggregation_decode_tp is None:
                self.disaggregation_decode_tp = self.tp_size
            if self.disaggregation_decode_dp is None:
                self.disaggregation_decode_dp = self.dp_size

            self.disaggregation_prefill_pp = self.pp_size
            self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)

            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for prefill server")
Byron Hsu's avatar
Byron Hsu committed
708

709
        # Propagate env vars
710
711
712
        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
713
714
715
716
        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )
717

Lianmin Zheng's avatar
Lianmin Zheng committed
718
719
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
Lianmin Zheng's avatar
Lianmin Zheng committed
720
        # Model and tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
721
722
        parser.add_argument(
            "--model-path",
723
            "--model",
Lianmin Zheng's avatar
Lianmin Zheng committed
724
725
726
727
728
729
730
731
732
733
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
734
735
736
737
738
739
740
741
742
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
743
744
745
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
746
            help="If set, skip init tokenizer and pass input_ids in generate request.",
747
        )
748
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
749
750
751
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
752
753
754
755
756
757
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
758
                "sharded_state",
759
760
                "gguf",
                "bitsandbytes",
761
                "layered",
762
                "remote",
763
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
764
765
766
767
768
769
770
771
772
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
773
            "which is mainly for profiling."
774
775
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
776
777
778
779
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
780
        )
781
782
783
784
785
786
787
        parser.add_argument(
            "--model-loader-extra-config",
            type=str,
            help="Extra config for model loader. "
            "This will be passed to the model loader corresponding to the chosen load_format.",
            default=ServerArgs.model_loader_extra_config,
        )
788
789
790
791
792
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--enable-multimodal",
            default=ServerArgs.enable_multimodal,
            action="store_true",
            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--model-impl",
            type=str,
            default=ServerArgs.model_impl,
            help="Which implementation of the model to use.\n\n"
            '* "auto" will try to use the SGLang implementation if it exists '
            "and fall back to the Transformers implementation if no SGLang "
            "implementation is available.\n"
            '* "sglang" will use the SGLang model implementation.\n'
            '* "transformers" will use the Transformers model '
            "implementation.\n",
        )

        # HTTP server
        parser.add_argument(
            "--host",
            type=str,
            default=ServerArgs.host,
            help="The host of the HTTP server.",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=ServerArgs.port,
            help="The port of the HTTP server.",
        )
        parser.add_argument(
            "--skip-server-warmup",
            action="store_true",
            help="If set, skip warmup.",
        )
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )
        parser.add_argument(
            "--nccl-port",
            type=int,
            default=ServerArgs.nccl_port,
            help="The port for NCCL distributed environment setup. Defaults to a random port.",
        )

        # Quantization and data type
Lianmin Zheng's avatar
Lianmin Zheng committed
864
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
865
            "--dtype",
Cody Yu's avatar
Cody Yu committed
866
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
867
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
868
869
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
870
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
871
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
872
873
874
875
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
876
877
            '* "float32" for FP32 precision.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
878
879
880
881
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
882
883
884
885
886
887
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
888
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
889
                "bitsandbytes",
890
                "gguf",
891
                "modelopt",
892
                "modelopt_fp4",
893
                "petit_nvfp4",
894
                "w8a8_int8",
HandH1998's avatar
HandH1998 committed
895
                "w8a8_fp8",
AniZpZ's avatar
AniZpZ committed
896
                "moe_wna16",
HandH1998's avatar
HandH1998 committed
897
                "qoq",
898
                "w4afp8",
899
                "mxfp4",
Ying Sheng's avatar
Ying Sheng committed
900
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
901
902
            help="The quantization method.",
        )
903
904
905
906
907
908
909
910
911
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
912
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
913
            "--kv-cache-dtype",
914
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
915
916
917
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
918
        )
919

920
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
921
922
923
924
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
925
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
926
        )
927
928
929
930
931
932
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
933
934
935
936
937
938
        parser.add_argument(
            "--max-queued-requests",
            type=int,
            default=ServerArgs.max_queued_requests,
            help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.",
        )
939
940
941
942
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
943
944
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
945
        )
946
947
948
949
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
950
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
951
952
953
954
955
956
957
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
958
        parser.add_argument(
959
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
960
            type=str,
961
            default=ServerArgs.schedule_policy,
962
            choices=["lpm", "random", "fcfs", "dfs-weight", "lof"],
963
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
964
        )
965
966
967
968
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
969
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
970
        )
971
972
973
974
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
975
            help="How many GBs of RAM to reserve for CPU offloading.",
976
        )
977
978
979
980
981
982
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
tarinkk's avatar
tarinkk committed
983
984
985
986
987
988
989
990
991
992
993
994
        parser.add_argument(
            "--hybrid-kvcache-ratio",
            nargs="?",
            const=0.5,
            type=float,
            default=ServerArgs.hybrid_kvcache_ratio,
            help=(
                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
                "(0.0 = pure uniform: swa_size / full_size = 1)"
                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
            ),
        )
Hanming Lu's avatar
Hanming Lu committed
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
        parser.add_argument(
            "--swa-full-tokens-ratio",
            type=float,
            default=ServerArgs.swa_full_tokens_ratio,
            help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
            "E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
        )
        parser.add_argument(
            "--disable-hybrid-swa-memory",
            action="store_true",
            help="Disable the hybrid SWA memory.",
        )
1007

Lianmin Zheng's avatar
Lianmin Zheng committed
1008
1009
1010
1011
1012
1013
1014
        # Runtime options
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1015
        parser.add_argument(
1016
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
1017
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
1018
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
1019
            default=ServerArgs.tp_size,
1020
            help="The tensor parallelism size.",
1021
        )
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
        parser.add_argument(
            "--pipeline-parallel-size",
            "--pp-size",
            type=int,
            default=ServerArgs.pp_size,
            help="The pipeline parallelism size.",
        )
        parser.add_argument(
            "--max-micro-batch-size",
            type=int,
            default=ServerArgs.max_micro_batch_size,
            help="The maximum micro batch size in pipeline parallelism.",
        )
1035
1036
1037
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
1038
            default=ServerArgs.stream_interval,
1039
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
1040
        )
1041
1042
1043
1044
1045
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1046
1047
1048
1049
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
1050
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1051
        )
1052
1053
1054
1055
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
Lianmin Zheng's avatar
Lianmin Zheng committed
1056
            help="(outlines backend only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
1057
        )
1058
1059
1060
1061
1062
1063
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
1064
1065
1066
1067
1068
1069
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
1070
1071
1072
1073
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
1074
            help="Model download directory for huggingface.",
1075
        )
1076
1077
1078
1079
1080
1081
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
1082
1083
1084
1085
1086
1087
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )
1088
1089
1090
1091
1092
        parser.add_argument(
            "--sleep-on-idle",
            action="store_true",
            help="Reduce CPU usage when sglang is idle.",
        )
1093
1094

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
1095
1096
1097
1098
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
1099
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1100
        )
1101
        parser.add_argument(
1102
1103
1104
1105
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
1106
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1107
        parser.add_argument(
1108
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
1109
            action="store_true",
1110
1111
1112
1113
1114
1115
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
1116
1117
1118
1119
1120
1121
1122
1123
            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
            choices=[0, 1, 2, 3],
        )
        parser.add_argument(
            "--crash-dump-folder",
            type=str,
            default=ServerArgs.crash_dump_folder,
            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1124
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1125
1126
1127
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
1128
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1129
        )
1130
1131
1132
1133
1134
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
1135
1136
1137
1138
1139
1140
1141
        parser.add_argument(
            "--enable-metrics-for-all-schedulers",
            action="store_true",
            help="Enable --enable-metrics-for-all-schedulers when you want schedulers on all TP ranks (not just TP 0) "
            "to record request metrics separately. This is especially useful when dp_attention is enabled, as "
            "otherwise all metrics appear to come from TP 0.",
        )
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
        parser.add_argument(
            "--bucket-time-to-first-token",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_time_to_first_token,
            help="The buckets of time to first token, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-inter-token-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_inter_token_latency,
            help="The buckets of inter-token latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-e2e-request-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_e2e_request_latency,
            help="The buckets of end-to-end request latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--collect-tokens-histogram",
            action="store_true",
            default=ServerArgs.collect_tokens_histogram,
            help="Collect prompt/generation tokens histogram.",
        )
1169
1170
1171
1172
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
1173
            help="The log interval of decode batch.",
1174
        )
1175
1176
1177
1178
1179
1180
        parser.add_argument(
            "--enable-request-time-stats-logging",
            action="store_true",
            default=ServerArgs.enable_request_time_stats_logging,
            help="Enable per request time stats logging",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1181
1182
1183
1184
1185
1186
        parser.add_argument(
            "--kv-events-config",
            type=str,
            default=None,
            help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
        )
1187

1188
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
1189
1190
1191
1192
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
1193
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1194
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
        )
1213
        parser.add_argument(
1214
            "--file-storage-path",
1215
            type=str,
1216
            default=ServerArgs.file_storage_path,
1217
1218
            help="The path of the file storage in backend.",
        )
1219
1220
1221
1222
1223
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Xihuai Wang's avatar
Xihuai Wang committed
1224
1225
1226
1227
1228
1229
1230
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )
1231
1232
1233
        parser.add_argument(
            "--tool-call-parser",
            type=str,
Atream's avatar
Atream committed
1234
1235
1236
1237
1238
1239
1240
            choices=[
                "qwen25",
                "mistral",
                "llama3",
                "deepseekv3",
                "pythonic",
                "kimi_k2",
1241
                "qwen3_coder",
Yuxuan Zhang's avatar
Yuxuan Zhang committed
1242
                "glm45",
Chang Su's avatar
Chang Su committed
1243
                "step3",
Atream's avatar
Atream committed
1244
            ],
1245
            default=ServerArgs.tool_call_parser,
Chang Su's avatar
Chang Su committed
1246
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
1247
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1248

1249
1250
        # Data parallelism
        parser.add_argument(
1251
            "--data-parallel-size",
1252
1253
1254
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
1255
            help="The data parallelism size.",
1256
1257
1258
1259
1260
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
1261
            help="The load balancing strategy for data parallelism.",
1262
1263
1264
            choices=[
                "round_robin",
                "shortest_queue",
1265
                "minimum_tokens",
1266
1267
            ],
        )
1268

1269
        # Multi-node distributed serving
1270
        parser.add_argument(
1271
            "--dist-init-addr",
1272
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
1273
            type=str,
1274
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
1275
1276
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
1277
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
1278
        )
1279
1280
1281
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
1282

Lianmin Zheng's avatar
Lianmin Zheng committed
1283
1284
1285
1286
1287
1288
1289
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
1290
1291
1292
1293
1294
        parser.add_argument(
            "--preferred-sampling-params",
            type=str,
            help="json-formatted sampling settings that will be returned in /get_model_info",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1295

1296
        # LoRA
1297
1298
1299
1300
1301
1302
        parser.add_argument(
            "--enable-lora",
            default=ServerArgs.enable_lora,
            action="store_true",
            help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.",
        )
1303
1304
1305
1306
1307
1308
1309
1310
1311
        parser.add_argument(
            "--max-lora-rank",
            default=ServerArgs.max_lora_rank,
            type=int,
            help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
        )
        parser.add_argument(
            "--lora-target-modules",
            type=str,
1312
            choices=SUPPORTED_LORA_TARGET_MODULES + [LORA_TARGET_ALL_MODULES],
1313
1314
            nargs="*",
            default=None,
1315
1316
1317
            help="The union set of all target modules where LoRA should be applied. If not specified, "
            "it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, "
            "all supported modules will be targeted.",
1318
        )
1319
1320
1321
1322
1323
1324
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
1325
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
1326
1327
1328
1329
1330
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
1331
1332
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
1333
1334
1335
1336
1337
1338
        parser.add_argument(
            "--max-loaded-loras",
            type=int,
            default=ServerArgs.max_loaded_loras,
            help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
        )
1339
1340
1341
1342
1343
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
1344
1345
1346
        )

        # Kernel backend
1347
1348
1349
        parser.add_argument(
            "--attention-backend",
            type=str,
1350
            choices=[
1351
                "aiter",
1352
                "cutlass_mla",
1353
                "fa3",
1354
                "flashinfer",
1355
                "flashmla",
1356
                "intel_amx",
1357
                "torch_native",
1358
                "ascend",
1359
                "triton",
1360
                "trtllm_mla",
1361
                "trtllm_mha",
1362
                "dual_chunk_flash_attn",
1363
            ],
1364
1365
1366
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
        parser.add_argument(
            "--decode-attention-backend",
            type=str,
            choices=[
                "flashinfer",
                "triton",
                "torch_native",
                "fa3",
                "flashmla",
                "cutlass_mla",
            ],
            default=ServerArgs.decode_attention_backend,
            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
        )

        parser.add_argument(
            "--prefill-attention-backend",
            type=str,
            choices=[
                "flashinfer",
                "triton",
                "torch_native",
                "fa3",
                "flashmla",
                "cutlass_mla",
            ],
            default=ServerArgs.prefill_attention_backend,
            help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
        )
1396
1397
1398
1399
1400
1401
1402
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
1403
1404
1405
        parser.add_argument(
            "--grammar-backend",
            type=str,
1406
            choices=["xgrammar", "outlines", "llguidance", "none"],
1407
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
1408
            help="Choose the backend for grammar-guided decoding.",
1409
        )
1410
1411
1412
1413
1414
1415
1416
        parser.add_argument(
            "--mm-attention-backend",
            type=str,
            choices=["sdpa", "fa3", "triton_attn"],
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )
1417

1418
1419
1420
1421
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
James Liu's avatar
James Liu committed
1422
            choices=["EAGLE", "EAGLE3", "NEXTN"],
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
1439
            help="The number of tokens sampled from the draft model in eagle2 each step.",
1440
1441
            default=ServerArgs.speculative_eagle_topk,
        )
1442
1443
1444
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
1445
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
1446
1447
            default=ServerArgs.speculative_num_draft_tokens,
        )
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
1460
1461
1462
1463
1464
1465
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )
1466
1467
1468
1469
1470

        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
Cheng Wan's avatar
Cheng Wan committed
1471
            "--ep",
1472
1473
1474
1475
1476
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
        parser.add_argument(
1477
1478
1479
1480
1481
            "--moe-a2a-backend",
            type=str,
            choices=["deepep"],
            default=ServerArgs.moe_a2a_backend,
            help="Choose the backend for MoE A2A.",
1482
        )
1483
        parser.add_argument(
1484
            "--enable-flashinfer-cutlass-moe",
1485
            action="store_true",
1486
            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
1487
        )
1488
        parser.add_argument(
1489
1490
            "--enable-flashinfer-trtllm-moe",
            action="store_true",
1491
            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
1492
1493
        )
        parser.add_argument(
1494
1495
1496
1497
            "--enable-flashinfer-allreduce-fusion",
            action="store_true",
            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
        )
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
        parser.add_argument(
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
            default="auto",
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
        parser.add_argument(
            "--ep-num-redundant-experts",
            type=int,
            default=ServerArgs.ep_num_redundant_experts,
            help="Allocate this number of redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--ep-dispatch-algorithm",
            type=str,
            default=ServerArgs.ep_dispatch_algorithm,
            help="The algorithm to choose ranks for redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--init-expert-location",
            type=str,
            default=ServerArgs.init_expert_location,
            help="Initial location of EP experts.",
        )
        parser.add_argument(
            "--enable-eplb",
            action="store_true",
            help="Enable EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-algorithm",
            type=str,
            default=ServerArgs.eplb_algorithm,
            help="Chosen EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-rebalance-num-iterations",
            type=int,
            default=ServerArgs.eplb_rebalance_num_iterations,
            help="Number of iterations to automatically trigger a EPLB re-balance.",
        )
        parser.add_argument(
            "--eplb-rebalance-layers-per-chunk",
            type=int,
            default=ServerArgs.eplb_rebalance_layers_per_chunk,
            help="Number of layers to rebalance per forward pass.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-mode",
            type=str,
            default=ServerArgs.expert_distribution_recorder_mode,
            help="Mode of expert distribution recorder.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-buffer-size",
            type=int,
            default=ServerArgs.expert_distribution_recorder_buffer_size,
            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
        )
        parser.add_argument(
            "--enable-expert-distribution-metrics",
            action="store_true",
            help="Enable logging metrics for expert balancedness",
        )
        parser.add_argument(
            "--deepep-config",
            type=str,
            default=ServerArgs.deepep_config,
            help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
        )
        parser.add_argument(
            "--moe-dense-tp-size",
            type=int,
            default=ServerArgs.moe_dense_tp_size,
            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
        )
1575

Lianmin Zheng's avatar
Lianmin Zheng committed
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
        # Hierarchical cache
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
        parser.add_argument(
            "--hicache-size",
            type=int,
            default=ServerArgs.hicache_size,
            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
        )
        parser.add_argument(
            "--hicache-write-policy",
            type=str,
            choices=["write_back", "write_through", "write_through_selective"],
            default=ServerArgs.hicache_write_policy,
            help="The write policy of hierarchical cache.",
        )
        parser.add_argument(
            "--hicache-io-backend",
            type=str,
            choices=["direct", "kernel"],
            default=ServerArgs.hicache_io_backend,
            help="The IO backend for KV cache transfer between CPU and GPU",
        )
1608
1609
1610
1611
1612
1613
1614
1615
        parser.add_argument(
            "--hicache-mem-layout",
            type=str,
            choices=["layer_first", "page_first"],
            default=ServerArgs.hicache_mem_layout,
            help="The layout of host memory pool for hierarchical cache.",
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
1616
1617
1618
        parser.add_argument(
            "--hicache-storage-backend",
            type=str,
1619
            choices=["file", "mooncake", "hf3fs", "nixl"],
Lianmin Zheng's avatar
Lianmin Zheng committed
1620
1621
1622
            default=ServerArgs.hicache_storage_backend,
            help="The storage backend for hierarchical KV cache.",
        )
pansicheng's avatar
pansicheng committed
1623
1624
1625
1626
1627
1628
1629
        parser.add_argument(
            "--hicache-storage-prefetch-policy",
            type=str,
            choices=["best_effort", "wait_complete", "timeout"],
            default=ServerArgs.hicache_storage_prefetch_policy,
            help="Control when prefetching from the storage backend should stop.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1630

1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

1668
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
1669
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
1670
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
1671
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
1672
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1673
        )
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
1686
1687
1688
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
1689
            help="Disable cuda graph.",
1690
        )
1691
        parser.add_argument(
1692
1693
            "--disable-cuda-graph-padding",
            action="store_true",
1694
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
1695
        )
1696
1697
1698
1699
1700
        parser.add_argument(
            "--enable-profile-cuda-graph",
            action="store_true",
            help="Enable profiling of cuda graph capture.",
        )
1701
1702
1703
1704
1705
        parser.add_argument(
            "--enable-cudagraph-gc",
            action="store_true",
            help="Enable garbage collection during CUDA graph capture. If disabled (default), GC is frozen during capture to speed up the process.",
        )
1706
1707
1708
1709
1710
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
1711
1712
1713
1714
1715
        parser.add_argument(
            "--enable-symm-mem",
            action="store_true",
            help="Enable NCCL symmetric memory for fast collectives.",
        )
1716
1717
1718
1719
1720
        parser.add_argument(
            "--enable-tokenizer-batch-encode",
            action="store_true",
            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
        )
1721
        parser.add_argument(
1722
            "--disable-outlines-disk-cache",
1723
            action="store_true",
1724
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
1725
        )
1726
1727
1728
1729
1730
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
1731
1732
1733
1734
1735
        parser.add_argument(
            "--enable-mscclpp",
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1736
        parser.add_argument(
1737
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
1738
            action="store_true",
1739
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1740
        )
1741
1742
1743
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
1744
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
1745
        )
Ke Bao's avatar
Ke Bao committed
1746
1747
1748
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
1749
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
Ke Bao's avatar
Ke Bao committed
1750
        )
1751
1752
1753
1754
1755
        parser.add_argument(
            "--enable-dp-lm-head",
            action="store_true",
            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
        )
1756
1757
1758
1759
1760
        parser.add_argument(
            "--enable-two-batch-overlap",
            action="store_true",
            help="Enabling two micro batches to overlap.",
        )
1761
1762
1763
1764
1765
1766
        parser.add_argument(
            "--tbo-token-distribution-threshold",
            type=float,
            default=ServerArgs.tbo_token_distribution_threshold,
            help="The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap.",
        )
1767
1768
1769
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
1770
1771
            help="Optimize the model with torch.compile. Experimental feature.",
        )
1772
        parser.add_argument(
1773
            "--torch-compile-max-bs",
1774
            type=int,
1775
            default=ServerArgs.torch_compile_max_bs,
1776
1777
            help="Set the maximum batch size when using torch compile.",
        )
1778
1779
1780
1781
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
1782
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
1783
        )
1784
1785
1786
1787
1788
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1789
        parser.add_argument(
1790
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
1791
            action="store_true",
1792
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1793
        )
1794
        parser.add_argument(
1795
            "--triton-attention-reduce-in-fp32",
1796
            action="store_true",
1797
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
1798
            "This only affects Triton attention kernels.",
1799
        )
1800
1801
1802
1803
1804
1805
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
1806
1807
1808
1809
1810
1811
1812
1813
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
1814
1815
1816
1817
1818
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
1819
1820
1821
1822
1823
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
1824
1825
1826
1827
1828
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
1829
1830
1831
1832
1833
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
1834
        parser.add_argument(
1835
            "--flashinfer-mla-disable-ragged",
1836
            action="store_true",
1837
            help="Not using ragged prefill wrapper when running flashinfer mla",
1838
        )
1839
        parser.add_argument(
1840
1841
1842
            "--disable-shared-experts-fusion",
            action="store_true",
            help="Disable shared experts fusion optimization for deepseek v3/r1.",
1843
        )
1844
1845
1846
1847
1848
        parser.add_argument(
            "--disable-chunked-prefix-cache",
            action="store_true",
            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1849
1850
1851
1852
1853
        parser.add_argument(
            "--disable-fast-image-processor",
            action="store_true",
            help="Adopt base image processor instead of fast image processor.",
        )
1854
1855
1856
1857
1858
        parser.add_argument(
            "--enable-return-hidden-states",
            action="store_true",
            help="Enable returning hidden states with responses.",
        )
Yuan Luo's avatar
Yuan Luo committed
1859
1860
1861
1862
1863
        parser.add_argument(
            "--enable-triton-kernel-moe",
            action="store_true",
            help="Use triton moe grouped gemm kernel.",
        )
1864
1865
1866
1867
1868
        parser.add_argument(
            "--enable-flashinfer-mxfp4-moe",
            action="store_true",
            help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
        )
1869
1870
1871
1872
1873
1874
        parser.add_argument(
            "--scheduler-recv-interval",
            type=int,
            default=ServerArgs.scheduler_recv_interval,
            help="The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this.",
        )
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )
1895
1896
1897
1898
1899
        parser.add_argument(
            "--debug-tensor-dump-prefill-only",
            action="store_true",
            help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
        )
1900

Lianmin Zheng's avatar
Lianmin Zheng committed
1901
        # PD disaggregation
Byron Hsu's avatar
Byron Hsu committed
1902
1903
1904
1905
1906
1907
1908
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
1909
1910
1911
1912
        parser.add_argument(
            "--disaggregation-transfer-backend",
            type=str,
            default=ServerArgs.disaggregation_transfer_backend,
1913
            choices=["mooncake", "nixl", "ascend"],
1914
1915
            help="The backend for disaggregation transfer. Default is mooncake.",
        )
1916
1917
1918
1919
1920
1921
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )
Byron Hsu's avatar
Byron Hsu committed
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
        parser.add_argument(
            "--disaggregation-decode-tp",
            type=int,
            default=ServerArgs.disaggregation_decode_tp,
            help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
        )
        parser.add_argument(
            "--disaggregation-decode-dp",
            type=int,
            default=ServerArgs.disaggregation_decode_dp,
            help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
        )
        parser.add_argument(
            "--disaggregation-prefill-pp",
            type=int,
            default=ServerArgs.disaggregation_prefill_pp,
            help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
        )
1940
1941
1942
1943
        parser.add_argument(
            "--disaggregation-ib-device",
            type=str,
            default=ServerArgs.disaggregation_ib_device,
1944
1945
1946
            help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
            "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
            "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
1947
        )
1948
1949
1950
1951
1952
1953
        parser.add_argument(
            "--num-reserved-decode-tokens",
            type=int,
            default=ServerArgs.num_reserved_decode_tokens,
            help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
        )
1954
1955
1956
1957
1958
1959
        parser.add_argument(
            "--pdlb-url",
            type=str,
            default=None,
            help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1960
1961

        # Custom weight loader
1962
1963
1964
1965
1966
1967
1968
        parser.add_argument(
            "--custom-weight-loader",
            type=str,
            nargs="*",
            default=None,
            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
        )
1969
1970
1971
1972
1973
        parser.add_argument(
            "--enable-pdmux",
            action="store_true",
            help="Enable PD-Multiplexing, PD running on greenctx stream.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1974
1975

        # For PD-Multiplexing
1976
1977
1978
1979
1980
1981
        parser.add_argument(
            "--sm-group-num",
            type=int,
            default=ServerArgs.sm_group_num,
            help="Number of sm partition groups.",
        )
1982
1983
1984
1985
1986
        parser.add_argument(
            "--weight-loader-disable-mmap",
            action="store_true",
            help="Disable mmap while loading weight using safetensors.",
        )
Byron Hsu's avatar
Byron Hsu committed
1987

1988
1989
1990
1991
1992
1993
1994
1995
        # For tool server
        parser.add_argument(
            "--tool-server",
            type=str,
            default=None,
            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
        )

1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
        # Deprecated arguments
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
2008
2009
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a parsed ``argparse.Namespace``.

        The parallelism sizes are stored on the namespace under long names
        (``tensor_parallel_size``, ``pipeline_parallel_size``,
        ``data_parallel_size``, ``expert_parallel_size``). They are first
        copied onto the short dataclass field names (``tp_size``, ``pp_size``,
        ``dp_size``, ``ep_size``). Note that this mutates ``args`` in place.
        Then every dataclass field is read back from the namespace.
        """
        alias_pairs = (
            ("tp_size", "tensor_parallel_size"),
            ("pp_size", "pipeline_parallel_size"),
            ("dp_size", "data_parallel_size"),
            ("ep_size", "expert_parallel_size"),
        )
        for short_name, long_name in alias_pairs:
            setattr(args, short_name, getattr(args, long_name))

        field_values = {}
        for field in dataclasses.fields(cls):
            field_values[field.name] = getattr(args, field.name)
        return cls(**field_values)

    def url(self):
        """Return the ``http://host:port`` base URL of this server.

        IPv6 literals are wrapped in square brackets, as URL syntax requires.
        """
        host_part = f"[{self.host}]" if is_valid_ipv6_address(self.host) else self.host
        return f"http://{host_part}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
2022

Lianmin Zheng's avatar
Lianmin Zheng committed
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
    def get_hf_config(self):
        """Load the Hugging Face model config for ``self.model_path``.

        ``self.json_model_override_args`` is decoded as JSON and forwarded to
        ``get_config`` as ``model_override_args``, together with
        ``trust_remote_code`` and ``revision``.
        """
        overrides = json.loads(self.json_model_override_args)
        return get_config(
            self.model_path,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
            model_override_args=overrides,
        )

2034
    def check_server_args(self):
        """Validate cross-argument constraints, raising ``AssertionError`` on conflicts.

        Besides checking, this method:
          * loads the HF config (via ``get_hf_config``) to inspect the first
            listed model architecture,
          * sets ``disable_hybrid_swa_memory = True`` for the Gemma-family
            architectures listed below, and
          * runs ``check_lora_server_args``, which may normalize LoRA fields
            in place.
        """
        # Check parallel size constraints
        assert (
            self.tp_size * self.pp_size
        ) % self.nnodes == 0, "tp_size * pp_size must be divisible by number of nodes"

        if self.pp_size > 1:
            assert (
                self.disable_overlap_schedule
                and self.speculative_algorithm is None
                and not self.enable_mixed_chunk
            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"

        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

        assert self.moe_dense_tp_size in {
            1,
            None,
        }, "moe_dense_tp_size only support 1 and None currently"

        # Check model architecture
        model_arch = self.get_hf_config().architectures[0]
        if "Llama4" in model_arch:
            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"

        if model_arch in (
            "Gemma2ForCausalLM",
            "Gemma3ForCausalLM",
            "Gemma3ForConditionalGeneration",
            "Gemma3nForCausalLM",
            "Gemma3nForConditionalGeneration",
        ):
            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
            logger.warning(
                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
            )
            self.disable_hybrid_swa_memory = True

        # Check LoRA
        self.check_lora_server_args()

        # Check speculative decoding: mixed chunked prefill must be OFF.
        if self.speculative_algorithm is not None:
            assert (
                not self.enable_mixed_chunk
            ), "enable_mixed_chunk is not supported with speculative decoding"

        # Check chunked prefill. Page alignment only matters when chunked
        # prefill is active. A non-positive size (-1 is the value used to
        # disable chunked prefill; confirm in the option's help text) would
        # otherwise fail this check whenever page_size > 1, since -1 % 16 == 15.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size > 0:
            assert (
                self.chunked_prefill_size % self.page_size == 0
            ), "chunked_prefill_size must be divisible by page_size"

2092
    def check_lora_server_args(self):
        """Validate LoRA-related arguments and normalize them in place.

        In-place effects:
          * ``enable_lora`` is set to True when ``lora_paths`` is given and
            ``enable_lora`` was left as ``None``.
          * When LoRA is enabled, ``lora_paths`` becomes a dict that maps
            adapter name -> ``LoRARef`` (``None`` becomes ``{}``).
          * When LoRA is enabled, ``lora_target_modules`` becomes a set, and
            ``"all"`` is expanded to ``SUPPORTED_LORA_TARGET_MODULES``.

        Raises:
            AssertionError: if the LoRA settings are inconsistent.
            ValueError: if ``lora_paths`` is not a list, a dict or None.
        """
        assert (
            self.max_loras_per_batch > 0
        ), f"max_loras_per_batch must be positive, got {self.max_loras_per_batch}"
        # FIXME: LoRA and the radix cache cannot be combined yet.
        assert (
            self.lora_paths is None or self.disable_radix_cache
        ), "compatibility of lora and radix attention is in progress"

        # Enable LoRA if any LoRA paths are provided for backward compatibility.
        if self.lora_paths:
            if self.enable_lora is None:
                self.enable_lora = True
                logger.info(
                    "--enable-lora is set to True because --lora-paths is provided."
                )
            elif self.enable_lora is False:
                logger.warning(
                    "--enable-lora is set to False, any provided lora_paths will be ignored."
                )

        if not self.enable_lora:
            return

        # Normalize lora_paths to a dictionary of name -> LoRARef.
        # TODO (lifuhuang): support specifying pinned adapters in server_args.
        if isinstance(self.lora_paths, list):
            normalized = {}
            for entry in self.lora_paths:
                # Entries are "name=path", or a bare path that doubles as its name.
                name, sep, path = entry.partition("=")
                if not sep:
                    name = path = entry
                normalized[name] = LoRARef(lora_name=name, lora_path=path, pinned=False)
            self.lora_paths = normalized
        elif isinstance(self.lora_paths, dict):
            # Keep values that are already LoRARef so that running this check
            # a second time does not wrap them twice.
            self.lora_paths = {
                name: (
                    path
                    if isinstance(path, LoRARef)
                    else LoRARef(lora_name=name, lora_path=path, pinned=False)
                )
                for name, path in self.lora_paths.items()
            }
        elif self.lora_paths is None:
            self.lora_paths = {}
        else:
            raise ValueError(
                f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
                "Expected a list or a dictionary."
            )

        # Expand target modules
        if self.lora_target_modules:
            self.lora_target_modules = set(self.lora_target_modules)
            if "all" in self.lora_target_modules:
                assert (
                    len(self.lora_target_modules) == 1
                ), "If 'all' is specified in --lora-target-modules, it should be the only module specified."
                self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES)

        # Ensure sufficient information is provided for LoRA initialization.
        assert self.lora_paths or (
            self.max_lora_rank and self.lora_target_modules
        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

        # Validate max_loaded_loras
        if self.max_loaded_loras is not None:
            assert self.max_loaded_loras >= self.max_loras_per_batch, (
                "max_loaded_loras should be greater than or equal to max_loras_per_batch. "
                f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
            )
            assert (
                not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
            ), (
                "The number of LoRA paths should not exceed max_loaded_loras. "
                f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
            )

Lianmin Zheng's avatar
Lianmin Zheng committed
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
    def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
        """Assert that the prefill and decode TP sizes can be paired.

        The two sizes may differ only when the larger one is an exact
        multiple of the smaller one.
        """
        smaller_tp, larger_tp = sorted((prefill_tp, decode_tp))
        assert larger_tp % smaller_tp == 0, (
            "Different tp size is supported only when one tp is multiple of the other. "
            f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
        )

    def adjust_mem_fraction_for_vlm(self, model_config):
        """Shrink ``mem_fraction_static`` for models that have a vision tower.

        If ``model_config.hf_config`` has no ``vision_config``, nothing changes.
        Otherwise the fraction is multiplied by ``0.95 * size_factor``:
          * The ViT weight count is estimated as
            ``num_hidden_layers * hidden_size**2``. It is compared against a
            ViT-L/14 baseline of 24 layers and hidden size 1024.
          * ``size_factor`` starts at 1.0 and drops by 0.1 for every +100% of
            relative size. It is clamped to [0.8, 1.05].
        """
        vision_config = getattr(model_config.hf_config, "vision_config", None)
        if vision_config is None:
            return

        previous_fraction = self.mem_fraction_static

        # Rough weight-count proxy for the ViT, relative to ViT-L/14.
        layers = getattr(vision_config, "num_hidden_layers", 24)
        hidden = getattr(vision_config, "hidden_size", 1024)
        relative_size = (layers * (hidden**2)) / (24 * (1024**2))

        size_factor = 1.0 - 0.1 * (relative_size - 1.0)
        size_factor = max(0.8, min(1.05, size_factor))

        self.mem_fraction_static = previous_fraction * (0.95 * size_factor)
        logger.warning(
            f"Multimodal model: Dynamically adjusted --mem-fraction-static "
            f"from: {previous_fraction:.3f} to: {self.mem_fraction_static:.3f}."
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
2217

Lianmin Zheng's avatar
Lianmin Zheng committed
2218
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments, without the program name. Typically
            this is ``sys.argv[1:]``. Passing an explicit list (even an empty
            one) means ``parse_args`` never falls back to reading ``sys.argv``
            itself.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    return ServerArgs.from_cli_args(raw_args)


2236
2237
2238
# Offset added to ``server_args.port`` to form the default ZMQ TCP address
# (on 127.0.0.1) used by PortArgs.init_new when DP attention runs on a single
# node and --dist-init-addr is not given.
ZMQ_TCP_PORT_DELTA = 233


Lianmin Zheng's avatar
Lianmin Zheng committed
2239
2240
@dataclasses.dataclass
class PortArgs:
    """ZMQ endpoints and the NCCL port used for inter-process communication.

    Each ``*_ipc_name`` field holds a full ZMQ address. It is either
    ``ipc://<temp file>`` for the single-node IPC layout, or
    ``tcp://<host>:<port>`` when DP attention is enabled.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    # The ipc filename for Scheduler to send metrics
    metrics_ipc_name: str

    @staticmethod
    def _new_ipc_name() -> str:
        """Reserve a unique temp-file path and return it as an ``ipc://`` address.

        The file is kept on disk (``delete=False``). Its handle is closed right
        away instead of being left open until garbage collection.
        """
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            return f"ipc://{tmp.name}"

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate the communication endpoints for one server instance.

        Args:
            server_args: The ServerArgs of the engine.
            dp_rank: In the DP-attention (TCP) layout, the data-parallel rank
                whose scheduler input port is built. ``None`` selects the
                TokenizerManager -> DataParallelController port.

        Raises:
            ValueError: if DP attention spans several nodes but
                ``dist_init_addr`` is not set.
            AssertionError: if ``dist_init_addr`` is not of the form host:port.
        """
        if server_args.nccl_port is None:
            # Probe from a random offset above the server port: step up by 42,
            # or down by 43 once the candidate reaches 60000.
            nccl_port = server_args.port + random.randint(100, 1000)
            while not is_port_available(nccl_port):
                if nccl_port < 60000:
                    nccl_port += 42
                else:
                    nccl_port -= 43
        else:
            nccl_port = server_args.nccl_port

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=PortArgs._new_ipc_name(),
                scheduler_input_ipc_name=PortArgs._new_ipc_name(),
                detokenizer_ipc_name=PortArgs._new_ipc_name(),
                nccl_port=nccl_port,
                rpc_ipc_name=PortArgs._new_ipc_name(),
                metrics_ipc_name=PortArgs._new_ipc_name(),
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
        elif server_args.dist_init_addr is None:
            # Multi-node without a head address: fail clearly instead of
            # crashing on None.startswith below.
            raise ValueError(
                "please provide --dist-init-addr as host:port of head node"
            )
        elif server_args.dist_init_addr.startswith("["):  # ipv6 address
            port_num, host = configure_ipv6(server_args.dist_init_addr)
            dist_init_addr = (host, str(port_num))
        else:
            dist_init_addr = server_args.dist_init_addr.split(":")

        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"

        dist_init_host, dist_init_port = dist_init_addr
        # Consecutive ports after dist_init_port: tokenizer, detokenizer, rpc,
        # metrics, then scheduler input ports.
        port_base = int(dist_init_port) + 1
        detokenizer_port = port_base + 1
        rpc_port = port_base + 2
        metrics_port = port_base + 3
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 4
        else:
            scheduler_input_port = port_base + 4 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
            nccl_port=nccl_port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
            metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_port}",
        )
2314

2315
2316
2317
2318
2319
2320
2321
2322
2323
2324

class LoRAPathAction(argparse.Action):
    """argparse action that collects LoRA adapters into a ``{name: path}`` dict.

    Each value is either ``name=path`` (split at the first ``=``) or a bare
    ``path``, which is then also used as the adapter name. Every occurrence of
    the option starts a fresh dict, replacing any earlier one.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        setattr(namespace, self.dest, adapters)
        for spec in values:
            name, has_sep, path = spec.partition("=")
            if has_sep:
                adapters[name] = path
            else:
                adapters[spec] = spec
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334


class DeprecatedAction(argparse.Action):
    """argparse action for removed options: passing the flag raises an error.

    The option takes no values by default (``nargs=0``). The raised
    ``ValueError`` carries the option's ``help`` text as its message.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)
2335
2336


2337
def auto_choose_speculative_params(self: ServerArgs):
    """
    Automatically choose the parameters for speculative decoding.

    Returns a 3-tuple picked from the model's first listed architecture:
    ``(3, 1, 4)`` for DeepSeek V2/V3 and ``(5, 4, 8)`` for everything else.
    Llama and Grok are listed explicitly so they can be tuned separately
    later.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
    """
    arch = self.get_hf_config().architectures[0]

    per_arch_defaults = {
        # The default value for llama
        "LlamaForCausalLM": (5, 4, 8),
        # The default value for deepseek
        "DeepseekV3ForCausalLM": (3, 1, 4),
        "DeepseekV2ForCausalLM": (3, 1, 4),
        "Grok1ForCausalLM": (5, 4, 8),
        "Grok1VForCausalLM": (5, 4, 8),
    }
    # The default value for all other models
    return per_arch_defaults.get(arch, (5, 4, 8))