server_args.py 99.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import json
19
import logging
20
import os
21
import random
22
import sys
23
import tempfile
24
from typing import List, Literal, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
25

26
from sglang.srt.function_call.function_call_parser import FunctionCallParser
27
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
28
from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
29
from sglang.srt.lora.lora_registry import LoRARef
Xihuai Wang's avatar
Xihuai Wang committed
30
from sglang.srt.reasoning_parser import ReasoningParser
31
from sglang.srt.utils import (
32
33
    LORA_TARGET_ALL_MODULES,
    SUPPORTED_LORA_TARGET_MODULES,
Vincent's avatar
Vincent committed
34
    configure_ipv6,
35
    get_device,
Lianmin Zheng's avatar
Lianmin Zheng committed
36
    get_device_memory_capacity,
37
    is_cuda,
38
    is_flashinfer_available,
HAI's avatar
HAI committed
39
    is_hip,
40
    is_port_available,
41
    is_remote_url,
42
    is_triton_kernels_available,
43
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
44
    nullable_str,
45
)
46

47
48
# Module-level logger. ServerArgs.__post_init__ uses it (logger.warning) to
# report arguments it adjusts automatically, e.g. page_size forced by an
# attention backend, CUDA graph or overlap schedule being disabled, or the
# chunked prefill size being rescaled under DP attention.
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
49
50
51

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
52
    # Model and tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
53
54
55
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
56
    skip_tokenizer_init: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
57
    load_format: str = "auto"
58
    model_loader_extra_config: str = "{}"
59
    trust_remote_code: bool = False
60
    context_length: Optional[int] = None
61
    is_embedding: bool = False
62
    enable_multimodal: Optional[bool] = None
63
    revision: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
64
    model_impl: str = "auto"
Lianmin Zheng's avatar
Lianmin Zheng committed
65

Lianmin Zheng's avatar
Lianmin Zheng committed
66
    # HTTP server
Lianmin Zheng's avatar
Lianmin Zheng committed
67
68
    host: str = "127.0.0.1"
    port: int = 30000
Lianmin Zheng's avatar
Lianmin Zheng committed
69
70
    skip_server_warmup: bool = False
    warmups: Optional[str] = None
71
    nccl_port: Optional[int] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
72

Lianmin Zheng's avatar
Lianmin Zheng committed
73
74
75
76
77
78
    # Quantization and data type
    dtype: str = "auto"
    quantization: Optional[str] = None
    quantization_param_path: Optional[str] = None
    kv_cache_dtype: str = "auto"

Lianmin Zheng's avatar
Lianmin Zheng committed
79
    # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
80
    mem_fraction_static: Optional[float] = None
81
    max_running_requests: Optional[int] = None
82
    max_queued_requests: Optional[int] = sys.maxsize
83
    max_total_tokens: Optional[int] = None
84
    chunked_prefill_size: Optional[int] = None
85
    max_prefill_tokens: int = 16384
86
    schedule_policy: str = "fcfs"
87
    schedule_conservativeness: float = 1.0
88
    page_size: Optional[int] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
89
90
91
    hybrid_kvcache_ratio: Optional[float] = None
    swa_full_tokens_ratio: float = 0.8
    disable_hybrid_swa_memory: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
92

Lianmin Zheng's avatar
Lianmin Zheng committed
93
94
    # Runtime options
    device: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
95
    tp_size: int = 1
96
97
    pp_size: int = 1
    max_micro_batch_size: Optional[int] = None
98
    stream_interval: int = 1
99
    stream_output: bool = False
100
    random_seed: Optional[int] = None
101
    constrained_json_whitespace_pattern: Optional[str] = None
102
    watchdog_timeout: float = 300
103
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
104
    download_dir: Optional[str] = None
105
    base_gpu_id: int = 0
106
    gpu_id_step: int = 1
107
    sleep_on_idle: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
108
109
110

    # Logging
    log_level: str = "info"
111
    log_level_http: Optional[str] = None
112
    log_requests: bool = False
113
    log_requests_level: int = 2
114
    crash_dump_folder: Optional[str] = None
Liangsheng Yin's avatar
Liangsheng Yin committed
115
    show_time_cost: bool = False
116
    enable_metrics: bool = False
117
    enable_metrics_for_all_schedulers: bool = False
118
119
    bucket_time_to_first_token: Optional[List[float]] = None
    bucket_inter_token_latency: Optional[List[float]] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
120
    bucket_e2e_request_latency: Optional[List[float]] = None
121
    collect_tokens_histogram: bool = False
122
    decode_log_interval: int = 40
123
    enable_request_time_stats_logging: bool = False
124
    kv_events_config: Optional[str] = None
125
    gc_warning_threshold_secs: float = 0.0
Liangsheng Yin's avatar
Liangsheng Yin committed
126

127
    # API related
128
    api_key: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
129
    served_model_name: Optional[str] = None
130
    weight_version: str = "default"
Lianmin Zheng's avatar
Lianmin Zheng committed
131
132
    chat_template: Optional[str] = None
    completion_template: Optional[str] = None
133
    file_storage_path: str = "sglang_storage"
134
    enable_cache_report: bool = False
Xihuai Wang's avatar
Xihuai Wang committed
135
    reasoning_parser: Optional[str] = None
136
    tool_call_parser: Optional[str] = None
137
    tool_server: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
138

139
140
141
    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"
142

143
    # Multi-node distributed serving
144
    dist_init_addr: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
145
    nnodes: int = 1
146
    node_rank: int = 0
Lianmin Zheng's avatar
Lianmin Zheng committed
147
148
149

    # Model override args in JSON
    json_model_override_args: str = "{}"
150
    preferred_sampling_params: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
151

152
    # LoRA
153
    enable_lora: Optional[bool] = None
154
    max_lora_rank: Optional[int] = None
155
    lora_target_modules: Optional[Union[set[str], List[str]]] = None
156
157
158
    lora_paths: Optional[
        Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
    ] = None
159
    max_loaded_loras: Optional[int] = None
160
    max_loras_per_batch: int = 8
161
    lora_backend: str = "triton"
162
163

    # Kernel backend
164
    attention_backend: Optional[str] = None
165
166
    decode_attention_backend: Optional[str] = None
    prefill_attention_backend: Optional[str] = None
167
    sampling_backend: Optional[str] = None
168
    grammar_backend: Optional[str] = None
169
    mm_attention_backend: Optional[str] = None
170

171
172
    # Speculative decoding
    speculative_algorithm: Optional[str] = None
173
    speculative_draft_model_path: Optional[str] = None
174
175
176
    speculative_num_steps: Optional[int] = None
    speculative_eagle_topk: Optional[int] = None
    speculative_num_draft_tokens: Optional[int] = None
177
178
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
179
    speculative_token_map: Optional[str] = None
180

181
182
    # Expert parallelism
    ep_size: int = 1
183
184
185
186
187
188
189
190
191
    moe_a2a_backend: Literal["none", "deepep"] = "none"
    moe_runner_backend: Literal[
        "auto",
        "triton",
        "triton_kernel",
        "flashinfer_trtllm",
        "flashinfer_cutlass",
        "flashinfer_mxfp4",
    ] = "auto"
192
    flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
193
    enable_flashinfer_allreduce_fusion: bool = False
194
    deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
    ep_num_redundant_experts: int = 0
    ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
    init_expert_location: str = "trivial"
    enable_eplb: bool = False
    eplb_algorithm: str = "auto"
    eplb_rebalance_num_iterations: int = 1000
    eplb_rebalance_layers_per_chunk: Optional[int] = None
    expert_distribution_recorder_mode: Optional[
        Literal["stat", "stat_approx", "per_pass", "per_token"]
    ] = None
    expert_distribution_recorder_buffer_size: Optional[int] = None
    enable_expert_distribution_metrics: bool = False
    deepep_config: Optional[str] = None
    moe_dense_tp_size: Optional[int] = None

Lianmin Zheng's avatar
Lianmin Zheng committed
210
211
212
213
214
    # Hierarchical cache
    enable_hierarchical_cache: bool = False
    hicache_ratio: float = 2.0
    hicache_size: int = 0
    hicache_write_policy: str = "write_through_selective"
215
216
    hicache_io_backend: str = "kernel"
    hicache_mem_layout: str = "layer_first"
Lianmin Zheng's avatar
Lianmin Zheng committed
217
    hicache_storage_backend: Optional[str] = None
pansicheng's avatar
pansicheng committed
218
    hicache_storage_prefetch_policy: str = "best_effort"
Lianmin Zheng's avatar
Lianmin Zheng committed
219

220
221
    # Double Sparsity
    enable_double_sparsity: bool = False
Vincent's avatar
Vincent committed
222
    ds_channel_config_path: Optional[str] = None
223
224
225
226
227
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

fzyzcjy's avatar
fzyzcjy committed
228
229
230
231
232
233
234
    # Offloading
    cpu_offload_gb: int = 0
    offload_group_size: int = -1
    offload_num_in_group: int = 1
    offload_prefetch_step: int = 1
    offload_mode: str = "cpu"

235
    # Optimization/debug options
Lianmin Zheng's avatar
Lianmin Zheng committed
236
    disable_radix_cache: bool = False
237
238
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
239
    disable_cuda_graph: bool = False
240
    disable_cuda_graph_padding: bool = False
241
    enable_profile_cuda_graph: bool = False
242
    enable_cudagraph_gc: bool = False
243
    enable_nccl_nvls: bool = False
244
    enable_symm_mem: bool = False
245
    disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
246
    enable_tokenizer_batch_encode: bool = False
247
    disable_outlines_disk_cache: bool = False
248
    disable_custom_all_reduce: bool = False
249
    enable_mscclpp: bool = False
250
    disable_overlap_schedule: bool = False
251
    enable_mixed_chunk: bool = False
Ke Bao's avatar
Ke Bao committed
252
    enable_dp_attention: bool = False
253
    enable_dp_lm_head: bool = False
254
    enable_two_batch_overlap: bool = False
255
    tbo_token_distribution_threshold: float = 0.48
256
    enable_torch_compile: bool = False
257
    torch_compile_max_bs: int = 32
258
    torchao_config: str = ""
259
    enable_nan_detection: bool = False
260
    enable_p2p_check: bool = False
261
    triton_attention_reduce_in_fp32: bool = False
262
    triton_attention_num_kv_splits: int = 8
263
    num_continuous_decode_steps: int = 1
264
    delete_ckpt_after_loading: bool = False
265
    enable_memory_saver: bool = False
266
    allow_auto_truncate: bool = False
267
    enable_custom_logit_processor: bool = False
268
    flashinfer_mla_disable_ragged: bool = False
269
    disable_shared_experts_fusion: bool = False
270
    disable_chunked_prefix_cache: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
271
    disable_fast_image_processor: bool = False
272
    enable_return_hidden_states: bool = False
273
    scheduler_recv_interval: int = 1
274
275
276
277
278

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False
279
    debug_tensor_dump_prefill_only: bool = False
280

Lianmin Zheng's avatar
Lianmin Zheng committed
281
    # PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
Byron Hsu's avatar
Byron Hsu committed
282
    disaggregation_mode: str = "null"
283
    disaggregation_transfer_backend: str = "mooncake"
284
    disaggregation_bootstrap_port: int = 8998
Byron Hsu's avatar
Byron Hsu committed
285
286
287
    disaggregation_decode_tp: Optional[int] = None
    disaggregation_decode_dp: Optional[int] = None
    disaggregation_prefill_pp: Optional[int] = 1
288
    disaggregation_ib_device: Optional[str] = None
289
    num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
290
    pdlb_url: Optional[str] = None
Byron Hsu's avatar
Byron Hsu committed
291

292
293
    # For model weight update
    custom_weight_loader: Optional[List[str]] = None
294
    weight_loader_disable_mmap: bool = False
295

296
297
298
299
    # For PD-Multiplexing
    enable_pdmux: bool = False
    sm_group_num: int = 3

300
301
302
    # Deprecated arguments
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
303
304
305
    enable_flashinfer_cutlass_moe: bool = False
    enable_flashinfer_trtllm_moe: bool = False
    enable_triton_kernel_moe: bool = False
306
    enable_flashinfer_mxfp4_moe: bool = False
307

Lianmin Zheng's avatar
Lianmin Zheng committed
308
    def __post_init__(self):
309
310
311
312
313
314
315
316
317
318
319
        # Check deprecated arguments
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            print_deprecated_warning(
                "NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
            )
        if self.enable_deepep_moe:
            self.moe_a2a_backend = "deepep"
            print_deprecated_warning(
                "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
            )
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
        if self.enable_triton_kernel_moe:
            self.moe_runner_backend = "triton_kernel"
            print_deprecated_warning(
                "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
            )
        if self.enable_flashinfer_cutlass_moe:
            self.moe_runner_backend = "flashinfer_cutlass"
            print_deprecated_warning(
                "NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
            )
        if self.enable_flashinfer_trtllm_moe:
            self.moe_runner_backend = "flashinfer_trtllm"
            print_deprecated_warning(
                "NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
            )
335
336
337
338
339
        if self.enable_flashinfer_mxfp4_moe:
            self.moe_runner_backend = "flashinfer_mxfp4"
            print_deprecated_warning(
                "NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead."
            )
340

341
        # Set missing default values
Lianmin Zheng's avatar
Lianmin Zheng committed
342
343
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
344
345
        if self.served_model_name is None:
            self.served_model_name = self.model_path
346
347
        if self.device is None:
            self.device = get_device()
348
349
350
        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

Lianmin Zheng's avatar
Lianmin Zheng committed
351
        gpu_mem = get_device_memory_capacity(self.device)
352

353
        # Set mem fraction static
Lianmin Zheng's avatar
Lianmin Zheng committed
354
        if self.mem_fraction_static is None:
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
            if gpu_mem is not None:
                # GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
                # mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity.

                # We want mem_fraction_static to be as large as possible but still has enough room
                # for activations and cuda graph buffers. We use the following heuristic to
                # compute the needed size for activations and cuda graph buffers:
                # - The size of the activation depends on the chunked_prefill_size and model size.
                # - The size of cuda graph buffers depends on the cuda graph capture range and model size.
                # For GPUs with more memory, we use a larger chunked_prefill_size and
                # capture more cuda graphs, so they need to reserve more memory.
                parallel_size = self.tp_size * self.pp_size

                if gpu_mem < 20 * 1024:
                    # T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                    reserved_mem = (2.8 + parallel_size / 10) * 1024
                elif gpu_mem < 35 * 1024:
                    # A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                    reserved_mem = (2.8 + parallel_size / 10) * 1024
                elif gpu_mem < 90 * 1024:
                    # H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160)
                    reserved_mem = (9.5 + parallel_size / 2) * 1024
                elif gpu_mem < 100 * 1024:
                    # H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                    reserved_mem = (12 + parallel_size / 2) * 1024
                elif gpu_mem < 160 * 1024:
                    # H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                    reserved_mem = (12 + parallel_size / 2) * 1024
383
                else:
384
385
386
                    # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
                    reserved_mem = 32 * 1024

387
                if self.speculative_algorithm is not None:
388
389
390
391
392
393
                    # draft model and larger cuda graph buffers
                    reserved_mem += 2 * 1024
                if self.enable_dp_attention:
                    reserved_mem += 4 * 1024

                self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3)
394
            else:
395
                self.mem_fraction_static = 0.88
396

397
            # Lazy init to avoid circular import
Lianmin Zheng's avatar
Lianmin Zheng committed
398
            # Multimodal models need more memory for the image processor
399
400
401
            from sglang.srt.configs.model_config import ModelConfig

            model_config = ModelConfig.from_server_args(self)
Lianmin Zheng's avatar
Lianmin Zheng committed
402
403
            if model_config.is_multimodal:
                self.adjust_mem_fraction_for_vlm(model_config)
404

405
406
        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
407
408
409
410
411
412
413
            if gpu_mem is not None:
                if gpu_mem < 35 * 1024:  # A10, L40, 4090
                    self.chunked_prefill_size = 2048
                elif gpu_mem < 160 * 1024:  # H100, H200, A100, H20
                    self.chunked_prefill_size = 8192
                else:  # B200, MI300
                    self.chunked_prefill_size = 16384
414
            else:
415
                self.chunked_prefill_size = 4096
Lianmin Zheng's avatar
Lianmin Zheng committed
416

417
418
419
420
421
422
423
424
425
        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
            if gpu_mem is not None and gpu_mem < 35 * 1024:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80

426
        # Set kernel backends for hpu device
427
428
429
430
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

431
432
433
        # Model-specific adjustments
        self.model_specific_adjustments()

Lianmin Zheng's avatar
Lianmin Zheng committed
434
        # Set kernel backends
435
436
437
438
439
        if self.device == "cpu":
            if self.attention_backend is None:
                self.attention_backend = "intel_amx"
            self.sampling_backend = "pytorch"

440
        if self.sampling_backend is None:
441
442
443
444
445
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
446
            logger.warning(
447
448
449
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True
450

451
452
453
454
455
456
        if self.attention_backend == "ascend":
            logger.warning(
                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

457
458
459
460
        if (
            self.attention_backend == "flashmla"
            or self.decode_attention_backend == "flashmla"
        ):
Lianmin Zheng's avatar
Lianmin Zheng committed
461
462
463
464
465
            logger.warning(
                "FlashMLA only supports a page_size of 64, change page_size to 64."
            )
            self.page_size = 64

466
467
468
469
        if (
            self.attention_backend == "cutlass_mla"
            or self.decode_attention_backend == "cutlass_mla"
        ):
Lianmin Zheng's avatar
Lianmin Zheng committed
470
471
472
473
474
            logger.warning(
                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

Faraz's avatar
Faraz committed
475
476
477
478
        if (
            self.attention_backend == "trtllm_mla"
            or self.decode_attention_backend == "trtllm_mla"
        ):
479
480
481
482
483
484
485
486
487
488
            if not is_sm100_supported():
                raise ValueError(
                    "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
                )

            if self.page_size not in [32, 64]:
                logger.warning(
                    f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
                )
                self.page_size = 64
Faraz's avatar
Faraz committed
489
490
491
492
493

            if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
                raise ValueError(
                    "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
                )
494

495
496
497
498
499
        if (
            self.attention_backend == "trtllm_mha"
            or self.decode_attention_backend == "trtllm_mha"
            or self.prefill_attention_backend == "trtllm_mha"
        ):
500
501
502
503
504
505
506
507
508
509
510
            if not is_sm100_supported():
                raise ValueError(
                    "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
                )

            if self.page_size not in [16, 32, 64]:
                logger.warning(
                    f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
                )
                self.page_size = 64

511
512
        if self.attention_backend == "dual_chunk_flash_attn":
            logger.warning(
513
                "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
514
515
516
517
518
            )
            self.enable_mixed_chunk = False
            self.disable_cuda_graph = True
            self.disable_radix_cache = True

519
520
521
522
523
524
525
526
        # Set page size
        if self.page_size is None:
            self.page_size = 1

        # AMD-specific Triton attention KV splits default number
        if is_hip():
            self.triton_attention_num_kv_splits = 16

527
528
529
        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"
530

531
        # Data parallelism attention
Ke Bao's avatar
Ke Bao committed
532
        if self.enable_dp_attention:
533
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
Lianmin Zheng's avatar
Lianmin Zheng committed
534
535
536
537
538
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
539
            logger.warning(
540
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
541
            )
542

543
544
545
        if self.enable_dp_lm_head:
            assert (
                self.enable_dp_attention
546
            ), "Please enable dp attention when setting enable_dp_lm_head. "
547

548
        # MoE kernel
549
        if self.moe_runner_backend == "flashinfer_cutlass":
550
551
552
            assert (
                self.quantization == "modelopt_fp4"
            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
553
554
555
556
            assert self.ep_size in [
                1,
                self.tp_size,
            ], "The expert parallel size must be 1 or the same as the tensor parallel size"
557

558
        if self.moe_runner_backend == "flashinfer_trtllm":
559
560
561
562
563
564
            if not self.disable_shared_experts_fusion:
                self.disable_shared_experts_fusion = True
                logger.warning(
                    "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
                )

565
        # DeepEP MoE
566
        if self.moe_a2a_backend == "deepep":
567
568
569
            if self.deepep_mode == "normal":
                logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                self.disable_cuda_graph = True
570
            self.ep_size = self.tp_size
Lianmin Zheng's avatar
Lianmin Zheng committed
571
            logger.warning(
572
573
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )
574

575
576
        if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
            self.expert_distribution_recorder_mode = "stat"
577
            logger.warning(
578
                "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
579
580
581
582
583
584
585
            )

        if (self.enable_eplb or (self.init_expert_location is not None)) and (
            self.ep_dispatch_algorithm is None
        ):
            self.ep_dispatch_algorithm = "static"

586
        if self.enable_eplb:
587
            assert self.ep_size > 1
588

589
590
591
592
593
        if self.enable_expert_distribution_metrics and (
            self.expert_distribution_recorder_mode is None
        ):
            self.expert_distribution_recorder_mode = "stat"

594
        if self.expert_distribution_recorder_buffer_size is None:
595
596
            if (x := self.eplb_rebalance_num_iterations) is not None:
                self.expert_distribution_recorder_buffer_size = x
597
598
599
            elif self.expert_distribution_recorder_mode is not None:
                self.expert_distribution_recorder_buffer_size = 1000

Lianmin Zheng's avatar
Lianmin Zheng committed
600
601
602
603
604
605
606
        # Pipeline parallelism
        if self.pp_size > 1:
            self.disable_overlap_schedule = True
            logger.warning(
                "Pipeline parallelism is incompatible with overlap schedule."
            )

Lianmin Zheng's avatar
Lianmin Zheng committed
607
        # Hicache
608
609
610
611
612
        if self.hicache_storage_backend == "mooncake":
            # to use mooncake storage backend, the following conditions must be met:
            self.hicache_io_backend = "kernel"
            self.hicache_mem_layout = "page_first"

613
        # Speculative Decoding
614
615
616
617
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

Lianmin Zheng's avatar
Lianmin Zheng committed
618
        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
619
            if self.max_running_requests is None:
620
                self.max_running_requests = 48
621
            self.disable_overlap_schedule = True
Lianmin Zheng's avatar
Lianmin Zheng committed
622
            logger.warning(
623
                "Overlap scheduler is disabled because of using "
624
                "eagle speculative decoding."
625
            )
626
627
628
629
630
631
            if self.enable_mixed_chunk:
                self.enable_mixed_chunk = False
                logger.warning(
                    "Mixed chunked prefill is disabled because of using "
                    "eagle speculative decoding."
                )
632

Lianmin Zheng's avatar
Lianmin Zheng committed
633
            model_arch = self.get_hf_config().architectures[0]
Yuxuan Zhang's avatar
Yuxuan Zhang committed
634
            if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
Hanming Lu's avatar
Hanming Lu committed
635
                # Auto set draft_model_path DeepSeek-V3/R1
636
637
638
639
640
641
                if self.speculative_draft_model_path is None:
                    self.speculative_draft_model_path = self.model_path
                else:
                    logger.warning(
                        "DeepSeek MTP does not require setting speculative_draft_model_path."
                    )
642

643
644
645
646
647
648
649
650
651
652
            # Auto choose parameters
            if self.speculative_num_steps is None:
                assert (
                    self.speculative_eagle_topk is None
                    and self.speculative_num_draft_tokens is None
                )
                (
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
653
                ) = auto_choose_speculative_params(self)
654

655
656
657
658
659
660
661
662
663
664
            if (
                self.attention_backend == "trtllm_mha"
                or self.decode_attention_backend == "trtllm_mha"
                or self.prefill_attention_backend == "trtllm_mha"
            ):
                if self.speculative_eagle_topk > 1:
                    raise ValueError(
                        "trtllm_mha backend only supports topk = 1 for speculative decoding."
                    )

665
666
667
668
            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
            ):
Lianmin Zheng's avatar
Lianmin Zheng committed
669
                logger.warning(
670
671
672
                    "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
                )
                self.speculative_num_draft_tokens = self.speculative_num_steps + 1
673

674
            # The token generated from the verify step is counted.
675
            # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
676
            # assert self.speculative_num_steps < self.speculative_num_draft_tokens
677

678
679
680
681
682
683
        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

684
        # Model loading
685
686
        if is_remote_url(self.model_path):
            self.load_format = "remote"
687
688
        if self.custom_weight_loader is None:
            self.custom_weight_loader = []
689

Byron Hsu's avatar
Byron Hsu committed
690
        # PD disaggregation
Byron Hsu's avatar
Byron Hsu committed
691
692
693
694
695
696
697
698
        if self.disaggregation_mode == "decode":
            assert (
                self.disaggregation_decode_tp is None
            ), "Cannot set --disaggregation-decode-tp for the decode engine."
            assert (
                self.disaggregation_decode_dp is None
            ), "Cannot set --disaggregation-decode-dp for the decode engine."

Byron Hsu's avatar
Byron Hsu committed
699
            self.disable_radix_cache = True
700
            logger.warning("KV cache is forced as chunk cache for decode server")
Byron Hsu's avatar
Byron Hsu committed
701
702
703
704
705
706
707
708
709
710
711
        elif self.disaggregation_mode == "prefill":
            if self.disaggregation_decode_tp is None:
                self.disaggregation_decode_tp = self.tp_size
            if self.disaggregation_decode_dp is None:
                self.disaggregation_decode_dp = self.dp_size

            self.disaggregation_prefill_pp = self.pp_size
            self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)

            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for prefill server")
Byron Hsu's avatar
Byron Hsu committed
712

713
        # Propagate env vars
714
715
716
        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
717
718
719
720
        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )
721

722
723
724
725
726
727
        if self.enable_hierarchical_cache and self.disable_radix_cache:
            raise ValueError(
                "The arguments enable-hierarchical-cache and disable-radix-cache are mutually exclusive "
                "and cannot be used at the same time. Please use only one of them."
            )

Lianmin Zheng's avatar
Lianmin Zheng committed
728
729
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
Lianmin Zheng's avatar
Lianmin Zheng committed
730
        # Model and tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
731
732
        parser.add_argument(
            "--model-path",
733
            "--model",
Lianmin Zheng's avatar
Lianmin Zheng committed
734
735
736
737
738
739
740
741
742
743
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
744
745
746
747
748
749
750
751
752
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
753
754
755
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
756
            help="If set, skip init tokenizer and pass input_ids in generate request.",
757
        )
758
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
759
760
761
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
762
763
764
765
766
767
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
768
                "sharded_state",
769
770
                "gguf",
                "bitsandbytes",
771
                "layered",
772
                "remote",
773
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
774
775
776
777
778
779
780
781
782
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
783
            "which is mainly for profiling."
784
785
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
786
787
788
789
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
790
        )
791
792
793
794
795
796
797
        parser.add_argument(
            "--model-loader-extra-config",
            type=str,
            help="Extra config for model loader. "
            "This will be passed to the model loader corresponding to the chosen load_format.",
            default=ServerArgs.model_loader_extra_config,
        )
798
799
800
801
802
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--enable-multimodal",
            default=ServerArgs.enable_multimodal,
            action="store_true",
            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--model-impl",
            type=str,
            default=ServerArgs.model_impl,
            help="Which implementation of the model to use.\n\n"
            '* "auto" will try to use the SGLang implementation if it exists '
            "and fall back to the Transformers implementation if no SGLang "
            "implementation is available.\n"
            '* "sglang" will use the SGLang model implementation.\n'
            '* "transformers" will use the Transformers model '
            "implementation.\n",
        )

        # HTTP server
        parser.add_argument(
            "--host",
            type=str,
            default=ServerArgs.host,
            help="The host of the HTTP server.",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=ServerArgs.port,
            help="The port of the HTTP server.",
        )
        parser.add_argument(
            "--skip-server-warmup",
            action="store_true",
            help="If set, skip warmup.",
        )
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )
        parser.add_argument(
            "--nccl-port",
            type=int,
            default=ServerArgs.nccl_port,
            help="The port for NCCL distributed environment setup. Defaults to a random port.",
        )

        # Quantization and data type
Lianmin Zheng's avatar
Lianmin Zheng committed
874
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
875
            "--dtype",
Cody Yu's avatar
Cody Yu committed
876
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
877
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
878
879
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
880
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
881
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
882
883
884
885
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
886
887
            '* "float32" for FP32 precision.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
888
889
890
891
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
892
893
894
895
896
897
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
898
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
899
                "bitsandbytes",
900
                "gguf",
901
                "modelopt",
902
                "modelopt_fp4",
903
                "petit_nvfp4",
904
                "w8a8_int8",
HandH1998's avatar
HandH1998 committed
905
                "w8a8_fp8",
AniZpZ's avatar
AniZpZ committed
906
                "moe_wna16",
HandH1998's avatar
HandH1998 committed
907
                "qoq",
908
                "w4afp8",
909
                "mxfp4",
Ying Sheng's avatar
Ying Sheng committed
910
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
911
912
            help="The quantization method.",
        )
913
914
915
916
917
918
919
920
921
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
922
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
923
            "--kv-cache-dtype",
924
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
925
926
927
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
928
        )
929

930
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
931
932
933
934
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
935
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
936
        )
937
938
939
940
941
942
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
943
944
945
946
947
948
        parser.add_argument(
            "--max-queued-requests",
            type=int,
            default=ServerArgs.max_queued_requests,
            help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.",
        )
949
950
951
952
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
953
954
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
955
        )
956
957
958
959
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
960
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
961
962
963
964
965
966
967
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
968
        parser.add_argument(
969
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
970
            type=str,
971
            default=ServerArgs.schedule_policy,
972
            choices=["lpm", "random", "fcfs", "dfs-weight", "lof"],
973
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
974
        )
975
976
977
978
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
979
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
980
        )
981
982
983
984
985
986
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
tarinkk's avatar
tarinkk committed
987
988
989
990
991
992
993
994
995
996
997
998
        parser.add_argument(
            "--hybrid-kvcache-ratio",
            nargs="?",
            const=0.5,
            type=float,
            default=ServerArgs.hybrid_kvcache_ratio,
            help=(
                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
                "(0.0 = pure uniform: swa_size / full_size = 1)"
                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
            ),
        )
Hanming Lu's avatar
Hanming Lu committed
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
        parser.add_argument(
            "--swa-full-tokens-ratio",
            type=float,
            default=ServerArgs.swa_full_tokens_ratio,
            help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
            "E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
        )
        parser.add_argument(
            "--disable-hybrid-swa-memory",
            action="store_true",
            help="Disable the hybrid SWA memory.",
        )
1011

Lianmin Zheng's avatar
Lianmin Zheng committed
1012
1013
1014
1015
1016
1017
1018
        # Runtime options
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1019
        parser.add_argument(
1020
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
1021
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
1022
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
1023
            default=ServerArgs.tp_size,
1024
            help="The tensor parallelism size.",
1025
        )
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
        parser.add_argument(
            "--pipeline-parallel-size",
            "--pp-size",
            type=int,
            default=ServerArgs.pp_size,
            help="The pipeline parallelism size.",
        )
        parser.add_argument(
            "--max-micro-batch-size",
            type=int,
            default=ServerArgs.max_micro_batch_size,
            help="The maximum micro batch size in pipeline parallelism.",
        )
1039
1040
1041
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
1042
            default=ServerArgs.stream_interval,
1043
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
1044
        )
1045
1046
1047
1048
1049
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1050
1051
1052
1053
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
1054
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1055
        )
1056
1057
1058
1059
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
Lianmin Zheng's avatar
Lianmin Zheng committed
1060
            help="(outlines backend only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
1061
        )
1062
1063
1064
1065
1066
1067
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
1068
1069
1070
1071
1072
1073
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
1074
1075
1076
1077
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
1078
            help="Model download directory for huggingface.",
1079
        )
1080
1081
1082
1083
1084
1085
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
1086
1087
1088
1089
1090
1091
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )
1092
1093
1094
1095
1096
        parser.add_argument(
            "--sleep-on-idle",
            action="store_true",
            help="Reduce CPU usage when sglang is idle.",
        )
1097
1098

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
1099
1100
1101
1102
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
1103
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1104
        )
1105
        parser.add_argument(
1106
1107
1108
1109
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
1110
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1111
        parser.add_argument(
1112
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
1113
            action="store_true",
1114
1115
1116
1117
1118
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
1119
            default=ServerArgs.log_requests_level,
1120
1121
1122
1123
1124
1125
1126
1127
            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
            choices=[0, 1, 2, 3],
        )
        parser.add_argument(
            "--crash-dump-folder",
            type=str,
            default=ServerArgs.crash_dump_folder,
            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1128
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1129
1130
1131
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
1132
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1133
        )
1134
1135
1136
1137
1138
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
1139
1140
1141
1142
1143
1144
1145
        parser.add_argument(
            "--enable-metrics-for-all-schedulers",
            action="store_true",
            help="Enable --enable-metrics-for-all-schedulers when you want schedulers on all TP ranks (not just TP 0) "
            "to record request metrics separately. This is especially useful when dp_attention is enabled, as "
            "otherwise all metrics appear to come from TP 0.",
        )
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
        parser.add_argument(
            "--bucket-time-to-first-token",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_time_to_first_token,
            help="The buckets of time to first token, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-inter-token-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_inter_token_latency,
            help="The buckets of inter-token latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-e2e-request-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_e2e_request_latency,
            help="The buckets of end-to-end request latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--collect-tokens-histogram",
            action="store_true",
            default=ServerArgs.collect_tokens_histogram,
            help="Collect prompt/generation tokens histogram.",
        )
1173
1174
1175
1176
1177
1178
        parser.add_argument(
            "--gc-warning-threshold-secs",
            type=float,
            default=ServerArgs.gc_warning_threshold_secs,
            help="The threshold for long GC warning. If a GC takes longer than this, a warning will be logged. Set to 0 to disable.",
        )
1179
1180
1181
1182
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
1183
            help="The log interval of decode batch.",
1184
        )
1185
1186
1187
1188
1189
1190
        parser.add_argument(
            "--enable-request-time-stats-logging",
            action="store_true",
            default=ServerArgs.enable_request_time_stats_logging,
            help="Enable per request time stats logging",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1191
1192
1193
1194
1195
1196
        parser.add_argument(
            "--kv-events-config",
            type=str,
            default=None,
            help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
        )
1197

1198
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
1199
1200
1201
1202
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
1203
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1204
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1205
1206
1207
1208
1209
1210
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
1211
1212
1213
1214
1215
1216
        parser.add_argument(
            "--weight-version",
            type=str,
            default=ServerArgs.weight_version,
            help="Version identifier for the model weights. Defaults to 'default' if not specified.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
        )
1229
        parser.add_argument(
1230
            "--file-storage-path",
1231
            type=str,
1232
            default=ServerArgs.file_storage_path,
1233
1234
            help="The path of the file storage in backend.",
        )
1235
1236
1237
1238
1239
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Xihuai Wang's avatar
Xihuai Wang committed
1240
1241
1242
1243
1244
1245
1246
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )
1247
        tool_call_parser_choices = list(FunctionCallParser.ToolCallParserEnum.keys())
1248
1249
1250
        parser.add_argument(
            "--tool-call-parser",
            type=str,
1251
            choices=tool_call_parser_choices,
1252
            default=ServerArgs.tool_call_parser,
1253
            help=f"Specify the parser for handling tool-call interactions. Options include: {tool_call_parser_choices}.",
1254
        )
1255
1256
1257
1258
1259
1260
        parser.add_argument(
            "--tool-server",
            type=str,
            default=None,
            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1261

1262
1263
        # Data parallelism
        parser.add_argument(
1264
            "--data-parallel-size",
1265
1266
1267
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
1268
            help="The data parallelism size.",
1269
1270
1271
1272
1273
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
1274
            help="The load balancing strategy for data parallelism.",
1275
1276
1277
            choices=[
                "round_robin",
                "shortest_queue",
1278
                "minimum_tokens",
1279
1280
            ],
        )
1281

1282
        # Multi-node distributed serving
1283
        parser.add_argument(
1284
            "--dist-init-addr",
1285
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
1286
            type=str,
1287
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
1288
1289
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
1290
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
1291
        )
1292
1293
1294
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
1295

Lianmin Zheng's avatar
Lianmin Zheng committed
1296
1297
1298
1299
1300
1301
1302
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
1303
1304
1305
1306
1307
        parser.add_argument(
            "--preferred-sampling-params",
            type=str,
            help="json-formatted sampling settings that will be returned in /get_model_info",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1308

1309
        # LoRA
1310
1311
1312
1313
1314
1315
        parser.add_argument(
            "--enable-lora",
            default=ServerArgs.enable_lora,
            action="store_true",
            help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.",
        )
1316
1317
1318
1319
1320
1321
1322
1323
1324
        parser.add_argument(
            "--max-lora-rank",
            default=ServerArgs.max_lora_rank,
            type=int,
            help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
        )
        parser.add_argument(
            "--lora-target-modules",
            type=str,
1325
            choices=SUPPORTED_LORA_TARGET_MODULES + [LORA_TARGET_ALL_MODULES],
1326
1327
            nargs="*",
            default=None,
1328
1329
1330
            help="The union set of all target modules where LoRA should be applied. If not specified, "
            "it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, "
            "all supported modules will be targeted.",
1331
        )
1332
1333
1334
1335
1336
1337
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
1338
            help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}',
1339
1340
1341
1342
1343
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
1344
1345
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
1346
1347
1348
1349
1350
1351
        parser.add_argument(
            "--max-loaded-loras",
            type=int,
            default=ServerArgs.max_loaded_loras,
            help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
        )
1352
1353
1354
1355
1356
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
1357
1358
1359
        )

        # Kernel backend
1360
        ATTN_BACKENDS = [
Lianmin Zheng's avatar
Lianmin Zheng committed
1361
1362
1363
1364
            # Common
            "triton",
            "torch_native",
            # NVIDIA specific
1365
1366
1367
1368
1369
1370
1371
            "cutlass_mla",
            "fa3",
            "flashinfer",
            "flashmla",
            "trtllm_mla",
            "trtllm_mha",
            "dual_chunk_flash_attn",
Lianmin Zheng's avatar
Lianmin Zheng committed
1372
1373
            # AMD specific
            "aiter",
1374
            "wave",
Lianmin Zheng's avatar
Lianmin Zheng committed
1375
1376
1377
            # Other platforms
            "intel_amx",
            "ascend",
1378
        ]
1379
1380
1381
        parser.add_argument(
            "--attention-backend",
            type=str,
1382
            choices=ATTN_BACKENDS,
1383
1384
1385
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
1386
1387
1388
        parser.add_argument(
            "--prefill-attention-backend",
            type=str,
1389
            choices=ATTN_BACKENDS,
1390
1391
1392
            default=ServerArgs.prefill_attention_backend,
            help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
        )
1393
1394
1395
1396
1397
1398
1399
        parser.add_argument(
            "--decode-attention-backend",
            type=str,
            choices=ATTN_BACKENDS,
            default=ServerArgs.decode_attention_backend,
            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
        )
1400
1401
1402
1403
1404
1405
1406
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
1407
1408
1409
        parser.add_argument(
            "--grammar-backend",
            type=str,
1410
            choices=["xgrammar", "outlines", "llguidance", "none"],
1411
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
1412
            help="Choose the backend for grammar-guided decoding.",
1413
        )
1414
1415
1416
1417
1418
1419
1420
        parser.add_argument(
            "--mm-attention-backend",
            type=str,
            choices=["sdpa", "fa3", "triton_attn"],
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )
1421

1422
1423
1424
1425
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
James Liu's avatar
James Liu committed
1426
            choices=["EAGLE", "EAGLE3", "NEXTN"],
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
1443
            help="The number of tokens sampled from the draft model in eagle2 each step.",
1444
1445
            default=ServerArgs.speculative_eagle_topk,
        )
1446
1447
1448
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
1449
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
1450
1451
            default=ServerArgs.speculative_num_draft_tokens,
        )
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
1464
1465
1466
1467
1468
1469
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )
1470
1471
1472
1473
1474

        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
Cheng Wan's avatar
Cheng Wan committed
1475
            "--ep",
1476
1477
1478
1479
1480
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
        parser.add_argument(
1481
1482
            "--moe-a2a-backend",
            type=str,
1483
            choices=["none", "deepep"],
1484
1485
            default=ServerArgs.moe_a2a_backend,
            help="Choose the backend for MoE A2A.",
1486
        )
1487
        parser.add_argument(
1488
1489
1490
1491
1492
1493
1494
1495
            "--moe-runner-backend",
            type=str,
            choices=[
                "auto",
                "triton",
                "triton_kernel",
                "flashinfer_trtllm",
                "flashinfer_cutlass",
1496
                "flashinfer_mxfp4",
1497
1498
1499
            ],
            default=ServerArgs.moe_runner_backend,
            help="Choose the runner backend for MoE.",
1500
1501
        )
        parser.add_argument(
1502
1503
1504
1505
1506
1507
1508
            "--flashinfer-mxfp4-moe-precision",
            type=str,
            choices=["mxfp4", "bf16"],
            default=ServerArgs.flashinfer_mxfp4_moe_precision,
            help="Choose the computation precision of flashinfer mxfp4 moe",
        )
        parser.add_argument(
1509
1510
            "--enable-flashinfer-allreduce-fusion",
            action="store_true",
1511
            help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
1512
        )
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
        parser.add_argument(
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
            default="auto",
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
        parser.add_argument(
            "--ep-num-redundant-experts",
            type=int,
            default=ServerArgs.ep_num_redundant_experts,
            help="Allocate this number of redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--ep-dispatch-algorithm",
            type=str,
            default=ServerArgs.ep_dispatch_algorithm,
            help="The algorithm to choose ranks for redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--init-expert-location",
            type=str,
            default=ServerArgs.init_expert_location,
            help="Initial location of EP experts.",
        )
        parser.add_argument(
            "--enable-eplb",
            action="store_true",
            help="Enable EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-algorithm",
            type=str,
            default=ServerArgs.eplb_algorithm,
            help="Chosen EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-rebalance-num-iterations",
            type=int,
            default=ServerArgs.eplb_rebalance_num_iterations,
            help="Number of iterations to automatically trigger a EPLB re-balance.",
        )
        parser.add_argument(
            "--eplb-rebalance-layers-per-chunk",
            type=int,
            default=ServerArgs.eplb_rebalance_layers_per_chunk,
            help="Number of layers to rebalance per forward pass.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-mode",
            type=str,
            default=ServerArgs.expert_distribution_recorder_mode,
            help="Mode of expert distribution recorder.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-buffer-size",
            type=int,
            default=ServerArgs.expert_distribution_recorder_buffer_size,
            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
        )
        parser.add_argument(
            "--enable-expert-distribution-metrics",
            action="store_true",
            help="Enable logging metrics for expert balancedness",
        )
        parser.add_argument(
            "--deepep-config",
            type=str,
            default=ServerArgs.deepep_config,
            help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
        )
        parser.add_argument(
            "--moe-dense-tp-size",
            type=int,
            default=ServerArgs.moe_dense_tp_size,
            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
        )
1590

Lianmin Zheng's avatar
Lianmin Zheng committed
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
        # Hierarchical cache
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
        parser.add_argument(
            "--hicache-size",
            type=int,
            default=ServerArgs.hicache_size,
            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
        )
        parser.add_argument(
            "--hicache-write-policy",
            type=str,
            choices=["write_back", "write_through", "write_through_selective"],
            default=ServerArgs.hicache_write_policy,
            help="The write policy of hierarchical cache.",
        )
        parser.add_argument(
            "--hicache-io-backend",
            type=str,
            choices=["direct", "kernel"],
            default=ServerArgs.hicache_io_backend,
            help="The IO backend for KV cache transfer between CPU and GPU",
        )
1623
1624
1625
1626
1627
1628
1629
        parser.add_argument(
            "--hicache-mem-layout",
            type=str,
            choices=["layer_first", "page_first"],
            default=ServerArgs.hicache_mem_layout,
            help="The layout of host memory pool for hierarchical cache.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1630
1631
1632
        parser.add_argument(
            "--hicache-storage-backend",
            type=str,
1633
            choices=["file", "mooncake", "hf3fs", "nixl"],
Lianmin Zheng's avatar
Lianmin Zheng committed
1634
1635
1636
            default=ServerArgs.hicache_storage_backend,
            help="The storage backend for hierarchical KV cache.",
        )
pansicheng's avatar
pansicheng committed
1637
1638
1639
1640
1641
1642
1643
        parser.add_argument(
            "--hicache-storage-prefetch-policy",
            type=str,
            choices=["best_effort", "wait_complete", "timeout"],
            default=ServerArgs.hicache_storage_prefetch_policy,
            help="Control when prefetching from the storage backend should stop.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1644

1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

fzyzcjy's avatar
fzyzcjy committed
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
        # Offloading
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading.",
        )
        parser.add_argument(
            "--offload-group-size",
            type=int,
            default=ServerArgs.offload_group_size,
            help="Number of layers per group in offloading.",
        )
        parser.add_argument(
            "--offload-num-in-group",
            type=int,
            default=ServerArgs.offload_num_in_group,
            help="Number of layers to be offloaded within a group.",
        )
        parser.add_argument(
            "--offload-prefetch-step",
            type=int,
            default=ServerArgs.offload_prefetch_step,
            help="Steps to prefetch in offloading.",
        )
        parser.add_argument(
            "--offload-mode",
            type=str,
            default=ServerArgs.offload_mode,
            help="Mode of offloading.",
        )

1714
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
1715
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
1716
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
1717
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
1718
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1719
        )
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
1732
1733
1734
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
1735
            help="Disable cuda graph.",
1736
        )
1737
        parser.add_argument(
1738
1739
            "--disable-cuda-graph-padding",
            action="store_true",
1740
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
1741
        )
1742
1743
1744
1745
1746
        parser.add_argument(
            "--enable-profile-cuda-graph",
            action="store_true",
            help="Enable profiling of cuda graph capture.",
        )
1747
1748
1749
1750
1751
        parser.add_argument(
            "--enable-cudagraph-gc",
            action="store_true",
            help="Enable garbage collection during CUDA graph capture. If disabled (default), GC is frozen during capture to speed up the process.",
        )
1752
1753
1754
1755
1756
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
1757
1758
1759
1760
1761
        parser.add_argument(
            "--enable-symm-mem",
            action="store_true",
            help="Enable NCCL symmetric memory for fast collectives.",
        )
1762
1763
1764
1765
1766
        parser.add_argument(
            "--disable-flashinfer-cutlass-moe-fp4-allgather",
            action="store_true",
            help="Disables quantize before all-gather for flashinfer cutlass moe.",
        )
1767
1768
1769
1770
1771
        parser.add_argument(
            "--enable-tokenizer-batch-encode",
            action="store_true",
            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
        )
1772
        parser.add_argument(
1773
            "--disable-outlines-disk-cache",
1774
            action="store_true",
1775
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
1776
        )
1777
1778
1779
1780
1781
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
1782
1783
1784
1785
1786
        parser.add_argument(
            "--enable-mscclpp",
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1787
        parser.add_argument(
1788
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
1789
            action="store_true",
1790
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1791
        )
1792
1793
1794
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
1795
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
1796
        )
Ke Bao's avatar
Ke Bao committed
1797
1798
1799
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
1800
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
Ke Bao's avatar
Ke Bao committed
1801
        )
1802
1803
1804
1805
1806
        parser.add_argument(
            "--enable-dp-lm-head",
            action="store_true",
            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
        )
1807
1808
1809
1810
1811
        parser.add_argument(
            "--enable-two-batch-overlap",
            action="store_true",
            help="Enabling two micro batches to overlap.",
        )
1812
1813
1814
1815
1816
1817
        parser.add_argument(
            "--tbo-token-distribution-threshold",
            type=float,
            default=ServerArgs.tbo_token_distribution_threshold,
            help="The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap.",
        )
1818
1819
1820
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
1821
1822
            help="Optimize the model with torch.compile. Experimental feature.",
        )
1823
        parser.add_argument(
1824
            "--torch-compile-max-bs",
1825
            type=int,
1826
            default=ServerArgs.torch_compile_max_bs,
1827
1828
            help="Set the maximum batch size when using torch compile.",
        )
1829
1830
1831
1832
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
1833
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
1834
        )
1835
1836
1837
1838
1839
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1840
        parser.add_argument(
1841
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
1842
            action="store_true",
1843
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1844
        )
1845
        parser.add_argument(
1846
            "--triton-attention-reduce-in-fp32",
1847
            action="store_true",
1848
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
1849
            "This only affects Triton attention kernels.",
1850
        )
1851
1852
1853
1854
1855
1856
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
1857
1858
1859
1860
1861
1862
1863
1864
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
1865
1866
1867
1868
1869
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
1870
1871
1872
1873
1874
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
1875
1876
1877
1878
1879
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
1880
1881
1882
1883
1884
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
1885
        parser.add_argument(
1886
            "--flashinfer-mla-disable-ragged",
1887
            action="store_true",
1888
            help="Not using ragged prefill wrapper when running flashinfer mla",
1889
        )
1890
        parser.add_argument(
1891
1892
1893
            "--disable-shared-experts-fusion",
            action="store_true",
            help="Disable shared experts fusion optimization for deepseek v3/r1.",
1894
        )
1895
1896
1897
1898
1899
        parser.add_argument(
            "--disable-chunked-prefix-cache",
            action="store_true",
            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1900
1901
1902
1903
1904
        parser.add_argument(
            "--disable-fast-image-processor",
            action="store_true",
            help="Adopt base image processor instead of fast image processor.",
        )
1905
1906
1907
1908
1909
        parser.add_argument(
            "--enable-return-hidden-states",
            action="store_true",
            help="Enable returning hidden states with responses.",
        )
1910
1911
1912
1913
1914
1915
        parser.add_argument(
            "--scheduler-recv-interval",
            type=int,
            default=ServerArgs.scheduler_recv_interval,
            help="The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this.",
        )
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )
1936
1937
1938
1939
1940
        parser.add_argument(
            "--debug-tensor-dump-prefill-only",
            action="store_true",
            help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
        )
1941

Lianmin Zheng's avatar
Lianmin Zheng committed
1942
        # PD disaggregation
Byron Hsu's avatar
Byron Hsu committed
1943
1944
1945
1946
1947
1948
1949
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
1950
1951
1952
1953
        parser.add_argument(
            "--disaggregation-transfer-backend",
            type=str,
            default=ServerArgs.disaggregation_transfer_backend,
1954
            choices=["mooncake", "nixl", "ascend"],
1955
1956
            help="The backend for disaggregation transfer. Default is mooncake.",
        )
1957
1958
1959
1960
1961
1962
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )
Byron Hsu's avatar
Byron Hsu committed
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
        parser.add_argument(
            "--disaggregation-decode-tp",
            type=int,
            default=ServerArgs.disaggregation_decode_tp,
            help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
        )
        parser.add_argument(
            "--disaggregation-decode-dp",
            type=int,
            default=ServerArgs.disaggregation_decode_dp,
            help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
        )
        parser.add_argument(
            "--disaggregation-prefill-pp",
            type=int,
            default=ServerArgs.disaggregation_prefill_pp,
            help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
        )
1981
1982
1983
1984
        parser.add_argument(
            "--disaggregation-ib-device",
            type=str,
            default=ServerArgs.disaggregation_ib_device,
1985
1986
1987
            help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
            "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
            "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
1988
        )
1989
1990
1991
1992
1993
1994
        parser.add_argument(
            "--num-reserved-decode-tokens",
            type=int,
            default=ServerArgs.num_reserved_decode_tokens,
            help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
        )
1995
1996
1997
1998
1999
2000
        parser.add_argument(
            "--pdlb-url",
            type=str,
            default=None,
            help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
2001
2002

        # Custom weight loader
2003
2004
2005
2006
2007
2008
2009
        parser.add_argument(
            "--custom-weight-loader",
            type=str,
            nargs="*",
            default=None,
            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
        )
2010
2011
2012
2013
2014
2015
2016
        parser.add_argument(
            "--weight-loader-disable-mmap",
            action="store_true",
            help="Disable mmap while loading weight using safetensors.",
        )

        # For PD-Multiplexing
2017
2018
2019
2020
2021
        parser.add_argument(
            "--enable-pdmux",
            action="store_true",
            help="Enable PD-Multiplexing, PD running on greenctx stream.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
2022

2023
2024
2025
2026
2027
2028
        parser.add_argument(
            "--sm-group-num",
            type=int,
            default=ServerArgs.sm_group_num,
            help="Number of sm partition groups.",
        )
Byron Hsu's avatar
Byron Hsu committed
2029

2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
        # Deprecated arguments
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
        )
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
        parser.add_argument(
            "--enable-flashinfer-cutlass-moe",
            action="store_true",
            help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
        )
        parser.add_argument(
            "--enable-flashinfer-trtllm-moe",
            action="store_true",
            help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
        )
        parser.add_argument(
            "--enable-triton-kernel-moe",
            action="store_true",
            help="(Deprecated) Use triton moe grouped gemm kernel.",
        )
2056
2057
2058
2059
2060
        parser.add_argument(
            "--enable-flashinfer-mxfp4-moe",
            action="store_true",
            help="(Deprecated) Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
        )
2061

Lianmin Zheng's avatar
Lianmin Zheng committed
2062
2063
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Create an instance from a parsed ``argparse.Namespace``.

        The CLI spells the parallelism sizes with long names
        (``--tensor-parallel-size`` etc.). They are first copied onto the short
        dataclass field names on ``args`` itself; then every dataclass field is
        read off the namespace.
        """
        long_to_short = (
            ("tensor_parallel_size", "tp_size"),
            ("pipeline_parallel_size", "pp_size"),
            ("data_parallel_size", "dp_size"),
            ("expert_parallel_size", "ep_size"),
        )
        for long_name, short_name in long_to_short:
            setattr(args, short_name, getattr(args, long_name))

        init_kwargs = {}
        for fld in dataclasses.fields(cls):
            init_kwargs[fld.name] = getattr(args, fld.name)
        return cls(**init_kwargs)

    def url(self):
        """Return the server's base HTTP URL.

        IPv6 literals are wrapped in square brackets so the ``:port`` suffix
        stays unambiguous.
        """
        host = f"[{self.host}]" if is_valid_ipv6_address(self.host) else self.host
        return f"http://{host}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
2076

Lianmin Zheng's avatar
Lianmin Zheng committed
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
    def get_hf_config(self):
        """Load the Hugging Face model config for ``self.model_path``.

        ``self.json_model_override_args`` is parsed as JSON and forwarded to
        ``get_config`` as ``model_override_args``, together with
        ``trust_remote_code`` and ``revision``.

        Returns:
            Whatever ``get_config`` returns for these arguments.
        """
        # The previous version splatted an always-empty ``kwargs`` dict here;
        # it had no effect and has been dropped.
        return get_config(
            self.model_path,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
            model_override_args=json.loads(self.json_model_override_args),
        )

2088
    def check_server_args(self):
        """Validate cross-argument constraints of the server configuration.

        Also runs ``self.check_lora_server_args()``, which may normalize the
        LoRA-related fields.

        Raises:
            AssertionError: if an unsupported combination of arguments is set.
        """
        # Check parallel size constraints
        assert (
            self.tp_size * self.pp_size
        ) % self.nnodes == 0, (
            "tp_size * pp_size must be divisible by the number of nodes (nnodes)"
        )

        if self.pp_size > 1:
            assert (
                self.disable_overlap_schedule
                and self.speculative_algorithm is None
                and not self.enable_mixed_chunk
            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"

        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

        assert self.moe_dense_tp_size in {
            1,
            None,
        }, "moe_dense_tp_size only supports 1 and None currently"

        # Check LoRA
        self.check_lora_server_args()

        # Check speculative decoding: mixed chunked prefill must be OFF.
        # (The old message claimed the opposite, i.e. that it was "required".)
        if self.speculative_algorithm is not None:
            assert (
                not self.enable_mixed_chunk
            ), "enable_mixed_chunk is not compatible with speculative decoding"

        # Check chunked prefill
        # Skip validation if chunked prefill is disabled (i.e., size <= 0).
        if self.chunked_prefill_size > 0:
            assert (
                self.chunked_prefill_size % self.page_size == 0
            ), "chunked_prefill_size must be divisible by page_size"
2128

2129
    def check_lora_server_args(self):
        """Validate LoRA arguments and normalize ``lora_paths`` to ``LoRARef`` items.

        Side effects:
            * ``enable_lora`` is set to True when ``lora_paths`` is given and
              ``enable_lora`` was left as None.
            * When LoRA is enabled, ``lora_paths`` becomes a list of ``LoRARef``,
              ``lora_target_modules`` becomes a set, and ``"all"`` is expanded to
              ``SUPPORTED_LORA_TARGET_MODULES``.

        Entries that are already ``LoRARef`` instances are kept as-is, so calling
        this method on already-normalized arguments is safe.

        Raises:
            AssertionError: on inconsistent LoRA limits or when there is not
                enough information to initialize LoRA.
            ValueError: if ``lora_paths`` or one of its items has an unsupported type.
        """
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"

        # Enable LoRA if any LoRA paths are provided for backward compatibility.
        if self.lora_paths:
            if self.enable_lora is None:
                self.enable_lora = True
                logger.warning(
                    "--enable-lora is set to True because --lora-paths is provided."
                )
            elif self.enable_lora is False:
                logger.warning(
                    "--enable-lora is set to False, any provided lora_paths will be ignored."
                )

        if self.enable_lora:
            if isinstance(self.lora_paths, list):
                lora_paths = self.lora_paths
                self.lora_paths = []
                for lora_path in lora_paths:
                    if isinstance(lora_path, LoRARef):
                        # Already normalized (e.g. re-validation, or built programmatically).
                        lora_ref = lora_path
                    elif isinstance(lora_path, str):
                        # "name=path" form; otherwise the path doubles as the name.
                        if "=" in lora_path:
                            name, path = lora_path.split("=", 1)
                            lora_ref = LoRARef(
                                lora_name=name, lora_path=path, pinned=False
                            )
                        else:
                            lora_ref = LoRARef(
                                lora_name=lora_path, lora_path=lora_path, pinned=False
                            )
                    elif isinstance(lora_path, dict):
                        assert (
                            "lora_name" in lora_path and "lora_path" in lora_path
                        ), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}"
                        lora_ref = LoRARef(
                            lora_name=lora_path["lora_name"],
                            lora_path=lora_path["lora_path"],
                            pinned=lora_path.get("pinned", False),
                        )
                    else:
                        raise ValueError(
                            f"Invalid type for item in --lora-paths list: {type(lora_path)}. "
                            "Expected a string, a dictionary, or a LoRARef."
                        )
                    self.lora_paths.append(lora_ref)
            elif isinstance(self.lora_paths, dict):
                self.lora_paths = [
                    LoRARef(lora_name=k, lora_path=v, pinned=False)
                    for k, v in self.lora_paths.items()
                ]
            elif self.lora_paths is None:
                self.lora_paths = []
            else:
                raise ValueError(
                    f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
                    "Expected a list or a dictionary."
                )

            # Expand target modules
            if self.lora_target_modules:
                self.lora_target_modules = set(self.lora_target_modules)
                if "all" in self.lora_target_modules:
                    assert (
                        len(self.lora_target_modules) == 1
                    ), "If 'all' is specified in --lora-target-modules, it should be the only module specified."
                    self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES)

            # Ensure sufficient information is provided for LoRA initialization.
            assert self.lora_paths or (
                self.max_lora_rank and self.lora_target_modules
            ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

            # Validate max_loaded_loras
            if self.max_loaded_loras is not None:
                assert self.max_loaded_loras >= self.max_loras_per_batch, (
                    "max_loaded_loras should be greater than or equal to max_loras_per_batch. "
                    f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
                )
                assert len(self.lora_paths) <= self.max_loaded_loras, (
                    "The number of LoRA paths should not exceed max_loaded_loras. "
                    f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
                )

Lianmin Zheng's avatar
Lianmin Zheng committed
2212
2213
2214
2215
2216
2217
2218
2219
    def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
        """Check that prefill and decode TP sizes can interoperate.

        The two sizes may differ only if the larger is an exact multiple of the
        smaller.

        Raises:
            AssertionError: if neither size divides the other.
        """
        smaller_tp, larger_tp = sorted((decode_tp, prefill_tp))
        assert larger_tp % smaller_tp == 0, (
            "Different tp size is supported only when one tp is multiple of the other. "
            f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
        )

2220
2221
2222
2223
2224
    def model_specific_adjustments(self):
        """Apply architecture-specific overrides based on the model's HF config.

        The first entry of ``hf_config.architectures`` selects the rules:

        * GptOssForCausalLM: picks/validates ``attention_backend``, may enable
          ``enable_flashinfer_allreduce_fusion``, chooses ``moe_runner_backend``,
          disables hybrid SWA memory, and forces ``dtype="bfloat16"`` for MXFP4
          checkpoints.
        * Any architecture containing "Llama4": requires the fa3 or aiter
          attention backend.
        * Gemma2/Gemma3/Gemma3n variants: disables hybrid SWA memory.

        Raises:
            AssertionError: if the configured backends are unsupported for the model.
        """
        hf_config = self.get_hf_config()
        model_arch = hf_config.architectures[0]
        if model_arch in ["GptOssForCausalLM"]:
            # Probe SM100 once, behind is_cuda(), consistent with the SM90 check
            # below. Previously the allreduce-fusion and MXFP4 checks called
            # is_sm100_supported() unguarded. NOTE(review): assumes the
            # is_sm*_supported helpers query CUDA device capability and are only
            # meaningful on CUDA builds.
            on_sm100 = is_cuda() and is_sm100_supported()

            if self.attention_backend is None:
                if on_sm100:
                    self.attention_backend = "trtllm_mha"
                elif is_cuda() and is_sm90_supported():
                    self.attention_backend = "fa3"
                else:
                    self.attention_backend = "triton"
            supported_backends = ["triton", "trtllm_mha", "fa3"]
            logger.info(
                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
            )
            assert (
                self.attention_backend in supported_backends
            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"

            if on_sm100 and not self.enable_dp_attention:
                self.enable_flashinfer_allreduce_fusion = True
                logger.info(
                    "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
                )

            quantization_config = getattr(hf_config, "quantization_config", None)
            is_mxfp4_quant_format = (
                quantization_config is not None
                and quantization_config.get("quant_method") == "mxfp4"
            )

            if on_sm100 and is_mxfp4_quant_format:
                self.moe_runner_backend = "flashinfer_mxfp4"
                logger.warning(
                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
                )
            else:
                if self.moe_runner_backend == "triton_kernel":
                    assert (
                        self.ep_size == 1
                    ), "Triton kernel MoE is only supported when ep_size == 1"
                # Auto-select triton_kernels MoE when possible (no expert parallelism).
                if (
                    self.moe_runner_backend == "auto"
                    and self.ep_size == 1
                    and is_triton_kernels_available()
                ):
                    self.moe_runner_backend = "triton_kernel"
                    logger.warning(
                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
                    )
            self.disable_hybrid_swa_memory = True
            if is_mxfp4_quant_format:
                # use bf16 for mxfp4 triton kernels
                self.dtype = "bfloat16"
        elif "Llama4" in model_arch:
            assert self.attention_backend in {
                "fa3",
                "aiter",
            }, "fa3 or aiter is required for Llama4 model"
        elif model_arch in [
            "Gemma2ForCausalLM",
            "Gemma3ForCausalLM",
            "Gemma3ForConditionalGeneration",
            "Gemma3nForCausalLM",
            "Gemma3nForConditionalGeneration",
        ]:
            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
            logger.warning(
                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
            )
            self.disable_hybrid_swa_memory = True

Lianmin Zheng's avatar
Lianmin Zheng committed
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
    def adjust_mem_fraction_for_vlm(self, model_config):
        """Scale down ``mem_fraction_static`` for vision-language models.

        The reduction is a base factor of 0.95 multiplied by a dynamic factor.
        The dynamic factor comes from the vision tower's size relative to a
        ViT-L/14 baseline (24 layers, hidden size 1024): each +100% of
        ``num_hidden_layers * hidden_size**2`` trims another 10%, and it is
        clamped to [0.8, 1.05]. Models without a ``vision_config`` are untouched.
        """
        vision_config = getattr(model_config.hf_config, "vision_config", None)
        if vision_config is None:
            return

        # Baseline ViT-L/14 shape; also used as the fallback for missing attributes.
        ref_layers, ref_hidden = 24, 1024
        layers = getattr(vision_config, "num_hidden_layers", ref_layers)
        hidden = getattr(vision_config, "hidden_size", ref_hidden)

        # Weight-count proxy relative to the baseline (baseline is a positive constant).
        size_ratio = (layers * (hidden**2)) / (ref_layers * (ref_hidden**2))

        dynamic = 1.0 - 0.1 * (size_ratio - 1.0)
        dynamic = max(0.8, min(1.05, dynamic))

        self.mem_fraction_static = self.mem_fraction_static * (0.95 * dynamic)

Lianmin Zheng's avatar
Lianmin Zheng committed
2331

Lianmin Zheng's avatar
Lianmin Zheng committed
2332
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`.
            Passing the list explicitly keeps the behavior predictable; with an
            empty list no CLI arguments are parsed, whereas `argparse` falls
            back to `sys.argv[1:]` only when given None.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    return ServerArgs.from_cli_args(raw_args)


2350
2351
2352
ZMQ_TCP_PORT_DELTA = 233


def _new_ipc_name() -> str:
    """Return a fresh ``ipc://`` endpoint backed by a unique temporary file path.

    The temp file is created (reserving the name) and its handle is closed
    right away instead of being left open until garbage collection.
    """
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        return f"ipc://{tmp.name}"


@dataclasses.dataclass
class PortArgs:
    """Communication endpoints shared by the server processes.

    Endpoints are ``ipc://`` temp-file paths in the normal single-node case,
    and ``tcp://host:port`` addresses when DP attention is enabled.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    # The ipc filename for Scheduler to send metrics
    metrics_ipc_name: str

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate the NCCL port and ZMQ endpoints for a server instance.

        Args:
            server_args: Server configuration. Reads ``nccl_port``, ``port``,
                ``enable_dp_attention``, ``nnodes`` and ``dist_init_addr``.
            dp_rank: With DP attention, selects a per-rank scheduler input port.
                None means the TokenizerManager -> DataParallelController port.

        Raises:
            AssertionError: if DP attention needs ``--dist-init-addr`` and it is
                missing or not in ``host:port`` form.
        """
        if server_args.nccl_port is None:
            nccl_port = server_args.port + random.randint(100, 1000)
            # Walk up (or down, near the top of the port range) until a free port is found.
            while not is_port_available(nccl_port):
                if nccl_port < 60000:
                    nccl_port += 42
                else:
                    nccl_port -= 43
        else:
            nccl_port = server_args.nccl_port

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=_new_ipc_name(),
                scheduler_input_ipc_name=_new_ipc_name(),
                detokenizer_ipc_name=_new_ipc_name(),
                nccl_port=nccl_port,
                rpc_ipc_name=_new_ipc_name(),
                metrics_ipc_name=_new_ipc_name(),
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
        else:
            # Multi-node without an address used to crash on None.startswith().
            assert (
                server_args.dist_init_addr is not None
            ), "please provide --dist-init-addr as host:port of head node"
            if server_args.dist_init_addr.startswith("["):  # ipv6 address
                port_num, host = configure_ipv6(server_args.dist_init_addr)
                dist_init_addr = (host, str(port_num))
            else:
                dist_init_addr = server_args.dist_init_addr.split(":")

        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"

        dist_init_host, dist_init_port = dist_init_addr
        # Consecutive ports after the dist-init port, one per endpoint.
        port_base = int(dist_init_port) + 1
        detokenizer_port = port_base + 1
        rpc_port = port_base + 2
        metrics_port = port_base + 3
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 4
        else:
            scheduler_input_port = port_base + 4 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
            nccl_port=nccl_port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
            metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_port}",
        )
2428

2429
2430
2431

class LoRAPathAction(argparse.Action):
    """argparse action that collects ``--lora-paths`` values.

    Each value is stripped of surrounding whitespace. A value wrapped in ``{}``
    is parsed as JSON and must contain ``lora_name`` and ``lora_path`` keys;
    any other value is kept as a plain string.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        collected = []
        if values:
            assert isinstance(values, list), "Expected a list of LoRA paths."
            for raw in values:
                entry = raw.strip()
                is_json_object = entry.startswith("{") and entry.endswith("}")
                if not is_json_object:
                    collected.append(entry)
                    continue
                parsed = json.loads(entry)
                assert "lora_path" in parsed and "lora_name" in parsed, (
                    f"{repr(entry)} looks like a JSON str, "
                    "but it does not contain 'lora_name' and 'lora_path' keys."
                )
                collected.append(parsed)

        setattr(namespace, self.dest, collected)
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457


class DeprecatedAction(argparse.Action):
    """argparse action for removed flags: using the flag raises ``ValueError``.

    The flag's ``help`` text is used as the error message.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        # By default the deprecated flag consumes no values.
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        message = self.help
        raise ValueError(message)
2458
2459


2460
2461
2462
2463
def print_deprecated_warning(message: str):
    """Log *message* as a warning, colored yellow with ANSI escape codes."""
    yellow, reset = "\033[33m", "\033[0m"
    logger.warning(f"{yellow}{message}{reset}")


2464
def auto_choose_speculative_params(self: ServerArgs):
    """
    Automatically choose the parameters for speculative decoding.

    Returns a 3-tuple chosen by the model's first HF architecture:
    ``(3, 1, 4)`` for DeepSeek V2/V3 and GPT-OSS, and ``(5, 4, 8)`` for
    everything else (including Llama and Grok).

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
    """
    arch = self.get_hf_config().architectures[0]

    deepseek_like = (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "GptOssForCausalLM",
    )
    if arch in deepseek_like:
        # The default value for deepseek and gpt-oss
        return (3, 1, 4)

    # Llama, Grok and all other models share the same defaults.
    return (5, 4, 8)