server_args.py 56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import json
19
import logging
20
import os
21
import random
22
import tempfile
23
from typing import List, Literal, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
Xihuai Wang's avatar
Xihuai Wang committed
26
from sglang.srt.reasoning_parser import ReasoningParser
27
from sglang.srt.utils import (
Vincent's avatar
Vincent committed
28
    configure_ipv6,
29
    get_device,
Lianmin Zheng's avatar
Lianmin Zheng committed
30
    get_device_memory_capacity,
31
    is_flashinfer_available,
HAI's avatar
HAI committed
32
    is_hip,
33
    is_port_available,
34
    is_remote_url,
35
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
36
    nullable_str,
37
)
38

39
40
# Module-level logger; ServerArgs reports its automatic argument adjustments
# through logger.warning(...).
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
41
42
43

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
44
    # Model and tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
45
46
47
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
48
    skip_tokenizer_init: bool = False
49
    enable_tokenizer_batch_encode: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
50
    load_format: str = "auto"
51
    trust_remote_code: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
52
    dtype: str = "auto"
53
    kv_cache_dtype: str = "auto"
Lianmin Zheng's avatar
Lianmin Zheng committed
54
    quantization: Optional[str] = None
Vincent's avatar
Vincent committed
55
    quantization_param_path: Optional[str] = None
56
    context_length: Optional[int] = None
57
    device: Optional[str] = None
58
    served_model_name: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
59
    chat_template: Optional[str] = None
60
    completion_template: Optional[str] = None
61
    is_embedding: bool = False
62
    revision: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
63

64
    # Port for the HTTP server
Lianmin Zheng's avatar
Lianmin Zheng committed
65
66
67
68
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
69
    mem_fraction_static: Optional[float] = None
70
    max_running_requests: Optional[int] = None
71
    max_total_tokens: Optional[int] = None
72
    chunked_prefill_size: Optional[int] = None
73
    max_prefill_tokens: int = 16384
74
    schedule_policy: str = "fcfs"
75
    schedule_conservativeness: float = 1.0
76
    cpu_offload_gb: int = 0
77
    page_size: int = 1
Lianmin Zheng's avatar
Lianmin Zheng committed
78
79
80

    # Other runtime options
    tp_size: int = 1
81
82
    pp_size: int = 1
    max_micro_batch_size: Optional[int] = None
83
    stream_interval: int = 1
84
    stream_output: bool = False
85
    random_seed: Optional[int] = None
86
    constrained_json_whitespace_pattern: Optional[str] = None
87
    watchdog_timeout: float = 300
88
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
89
    download_dir: Optional[str] = None
90
    base_gpu_id: int = 0
91
    gpu_id_step: int = 1
Lianmin Zheng's avatar
Lianmin Zheng committed
92
93
94

    # Logging
    log_level: str = "info"
95
    log_level_http: Optional[str] = None
96
    log_requests: bool = False
97
    log_requests_level: int = 0
Liangsheng Yin's avatar
Liangsheng Yin committed
98
    show_time_cost: bool = False
99
    enable_metrics: bool = False
100
    decode_log_interval: int = 40
Liangsheng Yin's avatar
Liangsheng Yin committed
101

102
    # API related
103
    api_key: Optional[str] = None
104
    file_storage_path: str = "sglang_storage"
105
    enable_cache_report: bool = False
Xihuai Wang's avatar
Xihuai Wang committed
106
    reasoning_parser: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
107

108
109
110
    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"
111

xiaobochen's avatar
xiaobochen committed
112
113
    # Expert parallelism
    ep_size: int = 1
114

115
    # Multi-node distributed serving
116
    dist_init_addr: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
117
    nnodes: int = 1
118
    node_rank: int = 0
Lianmin Zheng's avatar
Lianmin Zheng committed
119
120
121
122

    # Model override args in JSON
    json_model_override_args: str = "{}"

123
124
125
    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
126
    lora_backend: str = "triton"
127
128

    # Kernel backend
129
130
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
131
    grammar_backend: Optional[str] = None
132

133
134
    # Speculative decoding
    speculative_algorithm: Optional[str] = None
135
    speculative_draft_model_path: Optional[str] = None
136
137
138
    speculative_num_steps: Optional[int] = None
    speculative_eagle_topk: Optional[int] = None
    speculative_num_draft_tokens: Optional[int] = None
139
140
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
141
    speculative_token_map: Optional[str] = None
142
143
144

    # Double Sparsity
    enable_double_sparsity: bool = False
Vincent's avatar
Vincent committed
145
    ds_channel_config_path: Optional[str] = None
146
147
148
149
150
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

151
    # Optimization/debug options
Lianmin Zheng's avatar
Lianmin Zheng committed
152
    disable_radix_cache: bool = False
153
    disable_cuda_graph: bool = False
154
    disable_cuda_graph_padding: bool = False
155
    enable_nccl_nvls: bool = False
156
    disable_outlines_disk_cache: bool = False
157
    disable_custom_all_reduce: bool = False
158
    enable_multimodal: Optional[bool] = None
159
    disable_overlap_schedule: bool = False
160
    enable_mixed_chunk: bool = False
Ke Bao's avatar
Ke Bao committed
161
    enable_dp_attention: bool = False
xiaobochen's avatar
xiaobochen committed
162
    enable_ep_moe: bool = False
163
    enable_deepep_moe: bool = False
164
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
165
    enable_torch_compile: bool = False
166
    torch_compile_max_bs: int = 32
167
    cuda_graph_max_bs: Optional[int] = None
168
    cuda_graph_bs: Optional[List[int]] = None
169
    torchao_config: str = ""
170
    enable_nan_detection: bool = False
171
    enable_p2p_check: bool = False
172
    triton_attention_reduce_in_fp32: bool = False
173
    triton_attention_num_kv_splits: int = 8
174
    num_continuous_decode_steps: int = 1
175
    delete_ckpt_after_loading: bool = False
176
    enable_memory_saver: bool = False
177
    allow_auto_truncate: bool = False
178
    enable_custom_logit_processor: bool = False
Vincent's avatar
Vincent committed
179
    tool_call_parser: Optional[str] = None
180
    enable_hierarchical_cache: bool = False
181
    hicache_ratio: float = 2.0
Zhiqiang Xie's avatar
Zhiqiang Xie committed
182
183
    hicache_size: int = 0
    hicache_write_policy: str = "write_through_selective"
184
    flashinfer_mla_disable_ragged: bool = False
185
    warmups: Optional[str] = None
186
    moe_dense_tp_size: Optional[int] = None
187
    n_share_experts_fusion: int = 0
188
    disable_chunked_prefix_cache: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
189
    disable_fast_image_processor: bool = False
190
191
192
193
194

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False
195

Byron Hsu's avatar
Byron Hsu committed
196
197
198
    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
    disaggregation_mode: str = "null"
    disaggregation_bootstrap_port: int = 8998
199
    disaggregation_transfer_backend: str = "mooncake"
200
    disaggregation_ib_device: Optional[str] = None
Byron Hsu's avatar
Byron Hsu committed
201

Lianmin Zheng's avatar
Lianmin Zheng committed
202
    def __post_init__(self):
        """Fill in derived defaults and reconcile mutually dependent options.

        Called automatically after the dataclass ``__init__``; mutates ``self``
        in place. Unset (``None``) fields receive defaults that depend on the
        device, its memory capacity (from ``get_device_memory_capacity``) and
        the parallelism sizes. Incompatible combinations are either adjusted
        (with a ``logger.warning``) or rejected. Also sets the new attribute
        ``enable_sp_layernorm`` and exports ``SGLANG_ENABLE_TORCH_COMPILE`` and
        ``SGLANG_DISABLE_OUTLINES_DISK_CACHE`` into ``os.environ``.

        Raises:
            AssertionError: if a positive ``chunked_prefill_size`` is not a
                multiple of the final ``page_size``; if ``moe_dense_tp_size``
                is not 1/None; if DP attention is enabled with ``dp_size <= 1``
                or with ``tp_size`` not divisible by ``dp_size``; if DeepEP
                ``auto`` mode is combined with DP attention; or if
                ``speculative_eagle_topk``/``speculative_num_draft_tokens`` are
                given without ``speculative_num_steps`` for EAGLE decoding.
        """
        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.warning(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.device is None:
            self.device = get_device()

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # May be None when the device capacity is unknown; every use below
        # must guard against that. The unit appears to be MiB
        # (81920 == 80 * 1024; 18 * 1024 is annotated as 15 GB + 3 GB below).
        gpu_mem = get_device_memory_capacity(self.device)

        # Set mem fraction static, which depends on the parallelism size
        if self.mem_fraction_static is None:
            parallel_size = self.tp_size * self.pp_size
            if gpu_mem is not None and gpu_mem <= 81920:
                # Larger parallel groups reserve more memory outside the
                # static pool, hence a smaller static fraction.
                if parallel_size >= 16:
                    self.mem_fraction_static = 0.79
                elif parallel_size >= 8:
                    self.mem_fraction_static = 0.81
                elif parallel_size >= 4:
                    self.mem_fraction_static = 0.85
                elif parallel_size >= 2:
                    self.mem_fraction_static = 0.87
                else:
                    self.mem_fraction_static = 0.88
            else:
                # Large-memory devices, or capacity unknown (gpu_mem is None).
                self.mem_fraction_static = 0.88
            if gpu_mem is not None and gpu_mem > 96 * 1024:
                mem_fraction = self.mem_fraction_static
                self.mem_fraction_static = min(
                    mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem,
                    (gpu_mem - 1024 * 18)
                    / gpu_mem,  # 15 GB + additional 3GB for cuda graph
                )

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None and gpu_mem < 25_000:
                self.chunked_prefill_size = 2048
            elif self.disaggregation_mode != "null":
                self.chunked_prefill_size = 16384
            else:
                self.chunked_prefill_size = 8192

        # Backend-imposed page sizes. Applied before the alignment check below
        # so that the check validates the page size that is actually used.
        if self.attention_backend == "flashmla":
            logger.warning(
                "FlashMLA only supports a page_size of 64, change page_size to 64."
            )
            self.page_size = 64

        if self.attention_backend == "cutlass_mla":
            logger.warning(
                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

        # A non-positive value (e.g. -1) disables chunked prefill, so only a
        # real chunk size has to be page aligned (-1 % page_size != 0 for any
        # page_size > 1).
        if self.chunked_prefill_size > 0:
            assert self.chunked_prefill_size % self.page_size == 0, (
                f"chunked_prefill_size ({self.chunked_prefill_size}) must be a "
                f"multiple of page_size ({self.page_size})"
            )

        assert self.moe_dense_tp_size in {
            1,
            None,
        }, "moe_dense_tp_size only support 1 and None currently"

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on
            # lower-end GPUs with HBM < 25G, you can either disable cuda graph
            # or set `cuda_graph_max_bs` to a very small value to reduce the
            # memory overhead of creating cuda graphs, with almost no impact on
            # performance. When serving models with TP4 or TP8, cuda graph is
            # needed to maintain high performance, so `cuda_graph_max_bs` is
            # set to 80 (half of the default value 160). Logs from TP4 serving
            # of qwen2-72b show 80 is sufficient and avoids the OOM issues the
            # original 160 caused on lower-end GPUs.
            if gpu_mem is not None and gpu_mem < 25_000:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80

        # Set kernel backends for hpu device
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        # Set kernel backends
        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True

        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness *= 0.3
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert (
                self.tp_size % self.dp_size == 0
            ), f"tp_size ({self.tp_size}) must be divisible by dp_size ({self.dp_size})"
            # NOTE(review): the divided size is not re-checked for page
            # alignment; confirm downstream tolerates a non-multiple.
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
            )

        # DeepEP MoE
        self.enable_sp_layernorm = False
        if self.enable_deepep_moe:
            if self.deepep_mode == "auto":
                assert (
                    not self.enable_dp_attention
                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
            self.ep_size = self.tp_size
            self.enable_sp_layernorm = (
                self.dp_size < self.tp_size if self.enable_dp_attention else True
            )
            logger.warning(
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
            if self.max_running_requests is None:
                self.max_running_requests = 48
            self.disable_overlap_schedule = True
            logger.warning(
                "Overlap scheduler is disabled because of using "
                "eagle speculative decoding."
            )

            model_arch = get_model_arch(self)

            # Auto set draft_model_path DeepSeek-V3/R1
            if model_arch == "DeepseekV3ForCausalLM":
                if self.speculative_draft_model_path is None:
                    self.speculative_draft_model_path = self.model_path
                else:
                    logger.warning(
                        "DeepSeek MTP does not require setting speculative_draft_model_path."
                    )

            # Auto choose parameters (all three or none of them)
            if self.speculative_num_steps is None:
                assert (
                    self.speculative_eagle_topk is None
                    and self.speculative_num_draft_tokens is None
                )
                (
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
                ) = auto_choose_speculative_params(model_arch)

            if self.page_size > 1 and self.speculative_eagle_topk > 1:
                self.speculative_eagle_topk = 1
                logger.warning(
                    "speculative_eagle_topk is adjusted to 1 when page_size > 1"
                )

            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
            ):
                logger.warning(
                    "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
                )
                self.speculative_num_draft_tokens = self.speculative_num_steps + 1

            # The token generated from the verify step is counted.
            # If speculative_num_steps >= speculative_num_draft_tokens, the
            # additional tokens will definitely be discarded.
            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

        # GGUF
        if self.load_format in ("auto", "gguf") and check_gguf_file(
            self.model_path
        ):
            self.quantization = self.load_format = "gguf"

        if is_remote_url(self.model_path):
            self.load_format = "remote"

        # AMD-specific Triton attention KV splits default number
        if is_hip():
            self.triton_attention_num_kv_splits = 16

        # PD disaggregation
        if self.disaggregation_mode == "prefill":
            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for prefill server")
        elif self.disaggregation_mode == "decode":
            self.disable_radix_cache = True
            logger.warning("KV cache is forced as chunk cache for decode server")

        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )
417

Lianmin Zheng's avatar
Lianmin Zheng committed
418
419
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
420
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
421
422
423
424
425
426
427
428
429
430
431
432
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
433
434
435
436
437
438
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
439
440
441
442
443
444
445
446
447
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
448
449
450
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
451
            help="If set, skip init tokenizer and pass input_ids in generate request.",
452
        )
453
454
455
456
457
        parser.add_argument(
            "--enable-tokenizer-batch-encode",
            action="store_true",
            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
        )
458
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
459
460
461
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
462
463
464
465
466
467
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
468
                "sharded_state",
469
470
                "gguf",
                "bitsandbytes",
471
                "layered",
472
                "remote",
473
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
474
475
476
477
478
479
480
481
482
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
483
            "which is mainly for profiling."
484
485
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
486
487
488
489
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
490
        )
491
492
493
494
495
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
496
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
497
            "--dtype",
Cody Yu's avatar
Cody Yu committed
498
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
499
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
500
501
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
502
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
503
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
504
505
506
507
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
508
509
            '* "float32" for FP32 precision.',
        )
510
511
512
513
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
bjmsong's avatar
bjmsong committed
514
515
516
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
517
518
519
520
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
521
522
523
524
525
526
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
527
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
528
                "bitsandbytes",
529
                "gguf",
530
                "modelopt",
531
                "modelopt_fp4",
532
                "w8a8_int8",
HandH1998's avatar
HandH1998 committed
533
                "w8a8_fp8",
AniZpZ's avatar
AniZpZ committed
534
                "moe_wna16",
Ying Sheng's avatar
Ying Sheng committed
535
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
536
537
            help="The quantization method.",
        )
538
539
540
541
542
543
544
545
546
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
547
548
549
550
551
552
553
554
555
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
556
557
            default=ServerArgs.device,
            help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
558
        )
559
560
561
562
563
564
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
565
566
567
568
569
570
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
571
572
573
574
575
576
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
        )
577
578
579
580
581
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
582
583
584
585
586
587
588
589
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
590

591
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
592
593
594
595
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
596
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
597
        )
598
599
600
601
602
603
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
604
605
606
607
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
608
609
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
610
        )
611
612
613
614
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
615
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
616
617
618
619
620
621
622
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
623
        parser.add_argument(
624
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
625
            type=str,
626
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
627
            choices=["lpm", "random", "fcfs", "dfs-weight"],
628
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
629
        )
630
631
632
633
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
634
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
635
        )
636
637
638
639
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
640
            help="How many GBs of RAM to reserve for CPU offloading.",
641
        )
642
643
644
645
646
647
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
648

649
        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
650
        parser.add_argument(
651
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
652
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
653
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
654
            default=ServerArgs.tp_size,
655
            help="The tensor parallelism size.",
656
        )
657
658
659
660
661
662
663
664
665
666
667
668
669
        parser.add_argument(
            "--pipeline-parallel-size",
            "--pp-size",
            type=int,
            default=ServerArgs.pp_size,
            help="The pipeline parallelism size.",
        )
        parser.add_argument(
            "--max-micro-batch-size",
            type=int,
            default=ServerArgs.max_micro_batch_size,
            help="The maximum micro batch size in pipeline parallelism.",
        )
670
671
672
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
673
            default=ServerArgs.stream_interval,
674
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
675
        )
676
677
678
679
680
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
681
682
683
684
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
685
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
686
        )
687
688
689
690
691
692
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
693
694
695
696
697
698
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
699
700
701
702
703
704
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
705
706
707
708
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
Lianmin Zheng's avatar
Lianmin Zheng committed
709
            help="Model download directory for huggingface.",
710
        )
711
712
713
714
715
716
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
717
718
719
720
721
722
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )
723
724

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
725
726
727
728
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
729
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
730
        )
731
        parser.add_argument(
732
733
734
735
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
736
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
737
        parser.add_argument(
738
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
739
            action="store_true",
740
741
742
743
744
745
746
747
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
            choices=[0, 1, 2],
Lianmin Zheng's avatar
Lianmin Zheng committed
748
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
749
750
751
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
752
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
753
        )
754
755
756
757
758
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
759
760
761
762
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
763
            help="The log interval of decode batch.",
764
        )
765

766
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
767
768
769
770
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
771
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
772
        )
773
        parser.add_argument(
774
            "--file-storage-path",
775
            type=str,
776
            default=ServerArgs.file_storage_path,
777
778
            help="The path of the file storage in backend.",
        )
779
780
781
782
783
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Xihuai Wang's avatar
Xihuai Wang committed
784
785
786
787
788
789
790
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
791

792
793
        # Data parallelism
        parser.add_argument(
794
            "--data-parallel-size",
795
796
797
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
798
            help="The data parallelism size.",
799
800
801
802
803
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
804
            help="The load balancing strategy for data parallelism.",
805
806
807
808
809
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
810

xiaobochen's avatar
xiaobochen committed
811
812
813
814
815
816
817
818
        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
819

820
        # Multi-node distributed serving
821
        parser.add_argument(
822
823
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
824
            type=str,
825
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
826
827
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
828
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
829
        )
830
831
832
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
833

Lianmin Zheng's avatar
Lianmin Zheng committed
834
835
836
837
838
839
840
841
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

842
843
844
845
846
847
848
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
849
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
850
851
852
853
854
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
855
856
857
858
859
860
861
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
862
863
864
        )

        # Kernel backend
865
866
867
        parser.add_argument(
            "--attention-backend",
            type=str,
868
869
870
871
872
873
874
875
            choices=[
                "flashinfer",
                "triton",
                "torch_native",
                "fa3",
                "flashmla",
                "cutlass_mla",
            ],
876
877
878
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
879
880
881
882
883
884
885
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
886
887
888
        parser.add_argument(
            "--grammar-backend",
            type=str,
889
            choices=["xgrammar", "outlines", "llguidance", "none"],
890
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
891
            help="Choose the backend for grammar-guided decoding.",
892
        )
893
894
        parser.add_argument(
            "--enable-flashinfer-mla",
895
896
            action=DeprecatedAction,
            help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
897
        )
lukec's avatar
lukec committed
898
899
        parser.add_argument(
            "--enable-flashmla",
900
901
            action=DeprecatedAction,
            help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
lukec's avatar
lukec committed
902
        )
903
904
905
906
907
        parser.add_argument(
            "--flashinfer-mla-disable-ragged",
            action="store_true",
            help="Not using ragged prefill wrapper when running flashinfer mla",
        )
908

909
910
911
912
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
James Liu's avatar
James Liu committed
913
            choices=["EAGLE", "EAGLE3", "NEXTN"],
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
930
            help="The number of tokens sampled from the draft model in eagle2 each step.",
931
932
            default=ServerArgs.speculative_eagle_topk,
        )
933
934
935
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
936
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
937
938
            default=ServerArgs.speculative_num_draft_tokens,
        )
939
940
941
942
943
944
945
946
947
948
949
950
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
951
952
953
954
955
956
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

995
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
996
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
997
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
998
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
999
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1000
        )
1001
1002
1003
1004
1005
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
1006
        parser.add_argument(
1007
1008
1009
1010
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
1011
1012
1013
1014
1015
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
1016
        parser.add_argument(
1017
            "--disable-outlines-disk-cache",
1018
            action="store_true",
1019
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
1020
        )
1021
1022
1023
1024
1025
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
1026
        parser.add_argument(
1027
1028
            "--enable-multimodal",
            default=ServerArgs.enable_multimodal,
1029
            action="store_true",
1030
            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
1031
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1032
        parser.add_argument(
1033
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
1034
            action="store_true",
1035
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1036
        )
1037
1038
1039
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
1040
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
1041
        )
Ke Bao's avatar
Ke Bao committed
1042
1043
1044
1045
1046
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
xiaobochen's avatar
xiaobochen committed
1047
1048
1049
1050
1051
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
1052
1053
1054
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
1055
1056
            help="Optimize the model with torch.compile. Experimental feature.",
        )
1057
        parser.add_argument(
1058
            "--torch-compile-max-bs",
1059
            type=int,
1060
            default=ServerArgs.torch_compile_max_bs,
1061
1062
            help="Set the maximum batch size when using torch compile.",
        )
1063
        parser.add_argument(
1064
            "--cuda-graph-max-bs",
1065
            type=int,
1066
            default=ServerArgs.cuda_graph_max_bs,
1067
1068
            help="Set the maximum batch size for cuda graph.",
        )
1069
1070
1071
1072
1073
1074
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
1075
1076
1077
1078
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
1079
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
1080
        )
1081
1082
1083
1084
1085
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1086
        parser.add_argument(
1087
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
1088
            action="store_true",
1089
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1090
        )
1091
        parser.add_argument(
1092
            "--triton-attention-reduce-in-fp32",
1093
            action="store_true",
1094
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
1095
            "This only affects Triton attention kernels.",
1096
        )
1097
1098
1099
1100
1101
1102
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
1103
1104
1105
1106
1107
1108
1109
1110
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
1111
1112
1113
1114
1115
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
1116
1117
1118
1119
1120
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
1121
1122
1123
1124
1125
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
1126
1127
1128
1129
1130
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
YAMY's avatar
YAMY committed
1131
1132
1133
        parser.add_argument(
            "--tool-call-parser",
            type=str,
1134
            choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
YAMY's avatar
YAMY committed
1135
            default=ServerArgs.tool_call_parser,
1136
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
YAMY's avatar
YAMY committed
1137
        )
1138
1139
1140
1141
1142
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
1143
1144
1145
1146
1147
1148
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
Zhiqiang Xie's avatar
Zhiqiang Xie committed
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
        parser.add_argument(
            "--hicache-size",
            type=int,
            default=ServerArgs.hicache_size,
            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
        )
        parser.add_argument(
            "--hicache-write-policy",
            type=str,
            choices=["write_back", "write_through", "write_through_selective"],
            default=ServerArgs.hicache_write_policy,
            help="The write policy of hierarchical cache.",
        )
1162
1163
1164
1165
1166
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="Enabling DeepEP MoE implementation for EP MoE.",
        )
1167
1168
1169
1170
1171
1172
        parser.add_argument(
            "--moe-dense-tp-size",
            type=int,
            default=ServerArgs.moe_dense_tp_size,
            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
        )
1173
1174
1175
1176
        parser.add_argument(
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
1177
            default="auto",
1178
1179
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
1180

1181
1182
1183
        parser.add_argument(
            "--n-share-experts-fusion",
            type=int,
1184
            default=0,
1185
1186
            help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
            "set it to tp_size can get best optimized performace.",
1187
        )
1188
1189
1190
1191
1192
        parser.add_argument(
            "--disable-chunked-prefix-cache",
            action="store_true",
            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1193
1194
1195
1196
1197
        parser.add_argument(
            "--disable-fast-image-processor",
            action="store_true",
            help="Adopt base image processor instead of fast image processor.",
        )
1198

1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
        # Server warmups
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )

Byron Hsu's avatar
Byron Hsu committed
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
        # Disaggregation
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )
1242
1243
1244
1245
        parser.add_argument(
            "--disaggregation-transfer-backend",
            type=str,
            default=ServerArgs.disaggregation_transfer_backend,
1246
            choices=["mooncake", "nixl"],
1247
1248
            help="The backend for disaggregation transfer. Default is mooncake.",
        )
1249
1250
1251
1252
        parser.add_argument(
            "--disaggregation-ib-device",
            type=str,
            default=ServerArgs.disaggregation_ib_device,
1253
1254
1255
            help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
            "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
            "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
1256
        )
Byron Hsu's avatar
Byron Hsu committed
1257

Lianmin Zheng's avatar
Lianmin Zheng committed
1258
1259
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed ``argparse.Namespace``.

        The parallelism options are registered with their long flag first
        (e.g. ``--tensor-parallel-size`` before ``--tp-size``), so argparse
        stores them under the long dest name. They are copied onto the short
        dataclass field names before construction. This mutates ``args`` in
        place.

        Every dataclass field of ``cls`` is then read from ``args``; a missing
        attribute raises ``AttributeError``.
        """
        # (argparse dest name, dataclass field name) pairs, applied in order.
        alias_pairs = (
            ("tensor_parallel_size", "tp_size"),
            ("pipeline_parallel_size", "pp_size"),
            ("data_parallel_size", "dp_size"),
            ("expert_parallel_size", "ep_size"),
        )
        for dest_name, field_name in alias_pairs:
            setattr(args, field_name, getattr(args, dest_name))

        field_names = [f.name for f in dataclasses.fields(cls)]
        kwargs = {}
        for name in field_names:
            kwargs[name] = getattr(args, name)
        return cls(**kwargs)

    def url(self):
        """Return the ``http://host:port`` base URL of this server.

        IPv6 literal hosts are wrapped in square brackets so that the port
        separator ``:`` is unambiguous.
        """
        host = self.host
        if not is_valid_ipv6_address(host):
            return f"http://{host}:{self.port}"
        return f"http://[{host}]:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
1272

1273
1274
    def check_server_args(self):
        """Validate cross-argument constraints and normalize ``lora_paths``.

        Raises:
            AssertionError: if an unsupported combination of arguments is set.

        Side effects:
            - When ``pp_size > 1``, forces ``disable_overlap_schedule = True``
              (logging a warning only if it was previously enabled).
            - Converts a list-valued ``lora_paths`` into a ``{name: path}``
              dict. Entries of the form ``name=path`` are split on the first
              ``=``; any other entry uses the path itself as its name.
        """
        assert (
            self.tp_size * self.pp_size
        ) % self.nnodes == 0, "tp_size * pp_size must be divisible by number of nodes"

        # FIXME pp constraints
        if self.pp_size > 1:
            # Overlap scheduling is incompatible with pipeline parallelism, so it
            # is switched off here; only warn when this actually changes it.
            if not self.disable_overlap_schedule:
                logger.warning("Turn off overlap schedule for pipeline parallelism.")
                self.disable_overlap_schedule = True
            assert (
                self.speculative_algorithm is None and not self.enable_mixed_chunk
            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"

        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

        if isinstance(self.lora_paths, list):
            normalized = {}
            for entry in self.lora_paths:
                # "name=path" -> (name, path); a bare path names itself.
                name, sep, path = entry.partition("=")
                normalized[name] = path if sep else entry
            self.lora_paths = normalized

Lianmin Zheng's avatar
Lianmin Zheng committed
1309

Lianmin Zheng's avatar
Lianmin Zheng committed
1310
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments to parse, typically `sys.argv[1:]`.
            Pass an explicit list: `parse_args(None)` would silently fall back
            to reading `sys.argv` itself.

    Returns:
        The server arguments built via `ServerArgs.from_cli_args`.

    Note:
        Invalid arguments make argparse print usage and raise `SystemExit`.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)

    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


1328
1329
1330
ZMQ_TCP_PORT_DELTA = 233


@dataclasses.dataclass
class PortArgs:
    """Endpoints and ports used for inter-process communication."""

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    @staticmethod
    def _new_ipc_name() -> str:
        """Return a fresh ``ipc://`` endpoint backed by a new temporary file path.

        The file is created with ``delete=False`` so the path stays reserved,
        and the handle is closed right away instead of being left to the GC.
        """
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            return f"ipc://{tmp.name}"

    @staticmethod
    def _resolve_dist_init_addr(server_args):
        """Return the ``(host, port)`` pair of the DP-attention head node.

        Falls back to ``127.0.0.1:<port + ZMQ_TCP_PORT_DELTA>`` only on a
        single node with no ``dist_init_addr``; bracketed addresses are parsed
        as IPv6 via ``configure_ipv6``.

        Raises:
            AssertionError: if ``dist_init_addr`` is missing on multi-node
                setups or is not of the form ``host:port``.
        """
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            return ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)

        # Multi-node without an address previously crashed on None.startswith.
        assert (
            server_args.dist_init_addr is not None
        ), "please provide --dist-init-addr as host:port of head node"

        if server_args.dist_init_addr.startswith("["):  # ipv6 address
            port_num, host = configure_ipv6(server_args.dist_init_addr)
            dist_init_addr = (host, str(port_num))
        else:
            dist_init_addr = server_args.dist_init_addr.split(":")

        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"
        return dist_init_addr

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate a new set of communication endpoints.

        Args:
            server_args: The server arguments (reads ``port``,
                ``enable_dp_attention``, ``nnodes`` and ``dist_init_addr``).
            dp_rank: Data-parallel rank; only used with DP attention to pick
                the scheduler input port. ``None`` selects the port used by
                the TokenizerManager to reach the DataParallelController.

        Returns:
            A ``PortArgs`` using IPC files on a single node, or TCP endpoints
            derived from the head node's address when DP attention is enabled.
        """
        # Start at a random offset above the server port and probe until a
        # free port is found; step down instead of up near the top of the range.
        port = server_args.port + random.randint(100, 1000)
        while not is_port_available(port):
            port = port + 42 if port < 60000 else port - 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=PortArgs._new_ipc_name(),
                scheduler_input_ipc_name=PortArgs._new_ipc_name(),
                detokenizer_ipc_name=PortArgs._new_ipc_name(),
                nccl_port=port,
                rpc_ipc_name=PortArgs._new_ipc_name(),
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        dist_init_host, dist_init_port = PortArgs._resolve_dist_init_addr(server_args)
        port_base = int(dist_init_port) + 1
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 3
        else:
            scheduler_input_port = port_base + 3 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            nccl_port=port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
        )
1396

1397
1398
1399
1400
1401
1402
1403
1404
1405
1406

class LoRAPathAction(argparse.Action):
    """Store the option's values on the namespace as a ``{name: path}`` dict.

    A value of the form ``NAME=PATH`` is split on its first ``=``; any other
    value is used both as the adapter name and as its path. Later duplicates
    overwrite earlier ones.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        def _to_pair(item):
            if "=" not in item:
                return item, item
            key, target = item.split("=", 1)
            return key, target

        setattr(namespace, self.dest, dict(_to_pair(item) for item in values))
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416


class DeprecatedAction(argparse.Action):
    """Action for a retired flag: using it raises ``ValueError`` with its help text.

    The action consumes no values by default (``nargs=0``), so the flag is
    rejected as soon as it appears on the command line.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        message = self.help
        raise ValueError(message)
1417
1418


1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
def get_model_arch(args: ServerArgs):
    """Load the model's HF config and return the first listed architecture.

    Args:
        args: Server arguments; reads ``model_path``, ``trust_remote_code``,
            ``revision`` and ``json_model_override_args`` (a JSON string that
            is decoded and forwarded as ``model_override_args``).

    Returns:
        ``architectures[0]`` of the loaded config.
    """
    # NOTE(review): indexing raises if the config's `architectures` is None or empty.
    load_options = dict(
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
        model_override_args=json.loads(args.json_model_override_args),
    )
    config = get_config(args.model_path, **load_options)
    architectures = config.architectures
    return architectures[0]


def auto_choose_speculative_params(arch: str):
    """
    Automatically choose the parameters for speculative decoding.

    Returns a 3-tuple of ints chosen by model architecture name; every
    architecture currently shares the same values.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
    """
    # Per-family tuned values, checked in order. The table is kept (rather than
    # a single constant) so one model family can be retuned in one place.
    tuned_by_family = (
        (("LlamaForCausalLM",), (5, 4, 8)),  # llama
        (("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"), (5, 4, 8)),  # deepseek
        (("Grok1ForCausalLM", "Grok1VForCausalLM"), (5, 4, 8)),  # grok
    )
    for family_archs, params in tuned_by_family:
        if arch in family_archs:
            return params

    # The default value for all other models
    return (5, 4, 8)