server_args.py 72.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import json
19
import logging
20
import os
21
import random
22
import tempfile
23
from typing import List, Literal, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
Xihuai Wang's avatar
Xihuai Wang committed
26
from sglang.srt.reasoning_parser import ReasoningParser
27
from sglang.srt.utils import (
Vincent's avatar
Vincent committed
28
    configure_ipv6,
29
    get_device,
Lianmin Zheng's avatar
Lianmin Zheng committed
30
    get_device_memory_capacity,
31
    is_flashinfer_available,
HAI's avatar
HAI committed
32
    is_hip,
33
    is_port_available,
34
    is_remote_url,
35
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
36
    nullable_str,
37
)
38

39
40
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
41
42
43

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
44
    # Model and tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
45
46
47
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
48
    skip_tokenizer_init: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
49
    load_format: str = "auto"
50
    model_loader_extra_config: str = "{}"
51
    trust_remote_code: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
52
    dtype: str = "auto"
53
    kv_cache_dtype: str = "auto"
Lianmin Zheng's avatar
Lianmin Zheng committed
54
    quantization: Optional[str] = None
Vincent's avatar
Vincent committed
55
    quantization_param_path: Optional[str] = None
56
    context_length: Optional[int] = None
57
    device: Optional[str] = None
58
    served_model_name: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
59
    chat_template: Optional[str] = None
60
    completion_template: Optional[str] = None
61
    is_embedding: bool = False
62
    enable_multimodal: Optional[bool] = None
63
    revision: Optional[str] = None
tarinkk's avatar
tarinkk committed
64
    hybrid_kvcache_ratio: Optional[float] = None
65
    impl: str = "auto"
Lianmin Zheng's avatar
Lianmin Zheng committed
66

67
    # Port for the HTTP server
Lianmin Zheng's avatar
Lianmin Zheng committed
68
69
70
71
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
72
    mem_fraction_static: Optional[float] = None
73
    max_running_requests: Optional[int] = None
74
    max_total_tokens: Optional[int] = None
75
    chunked_prefill_size: Optional[int] = None
76
    max_prefill_tokens: int = 16384
77
    schedule_policy: str = "fcfs"
78
    schedule_conservativeness: float = 1.0
79
    cpu_offload_gb: int = 0
80
    page_size: int = 1
Lianmin Zheng's avatar
Lianmin Zheng committed
81
82
83

    # Other runtime options
    tp_size: int = 1
84
85
    pp_size: int = 1
    max_micro_batch_size: Optional[int] = None
86
    stream_interval: int = 1
87
    stream_output: bool = False
88
    random_seed: Optional[int] = None
89
    constrained_json_whitespace_pattern: Optional[str] = None
90
    watchdog_timeout: float = 300
91
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
92
    download_dir: Optional[str] = None
93
    base_gpu_id: int = 0
94
    gpu_id_step: int = 1
95
    sleep_on_idle: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
96
97
98

    # Logging
    log_level: str = "info"
99
    log_level_http: Optional[str] = None
100
    log_requests: bool = False
101
    log_requests_level: int = 0
102
    crash_dump_folder: Optional[str] = None
Liangsheng Yin's avatar
Liangsheng Yin committed
103
    show_time_cost: bool = False
104
    enable_metrics: bool = False
105
106
107
108
    bucket_time_to_first_token: Optional[List[float]] = None
    bucket_e2e_request_latency: Optional[List[float]] = None
    bucket_inter_token_latency: Optional[List[float]] = None
    collect_tokens_histogram: bool = False
109
    decode_log_interval: int = 40
110
    enable_request_time_stats_logging: bool = False
111
    kv_events_config: Optional[str] = None
Liangsheng Yin's avatar
Liangsheng Yin committed
112

113
    # API related
114
    api_key: Optional[str] = None
115
    file_storage_path: str = "sglang_storage"
116
    enable_cache_report: bool = False
Xihuai Wang's avatar
Xihuai Wang committed
117
    reasoning_parser: Optional[str] = None
118
    tool_call_parser: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
119

120
121
122
    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"
123

124
    # Multi-node distributed serving
125
    dist_init_addr: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
126
    nnodes: int = 1
127
    node_rank: int = 0
Lianmin Zheng's avatar
Lianmin Zheng committed
128
129
130

    # Model override args in JSON
    json_model_override_args: str = "{}"
131
    preferred_sampling_params: Optional[str] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
132

133
134
135
    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
136
    lora_backend: str = "triton"
137
138

    # Kernel backend
139
140
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
141
    grammar_backend: Optional[str] = None
142
    mm_attention_backend: Optional[str] = None
143

144
145
    # Speculative decoding
    speculative_algorithm: Optional[str] = None
146
    speculative_draft_model_path: Optional[str] = None
147
148
149
    speculative_num_steps: Optional[int] = None
    speculative_eagle_topk: Optional[int] = None
    speculative_num_draft_tokens: Optional[int] = None
150
151
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
152
    speculative_token_map: Optional[str] = None
153

154
155
156
157
    # Expert parallelism
    ep_size: int = 1
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
158
    enable_flashinfer_moe: bool = False
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
    ep_num_redundant_experts: int = 0
    ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
    init_expert_location: str = "trivial"
    enable_eplb: bool = False
    eplb_algorithm: str = "auto"
    eplb_rebalance_num_iterations: int = 1000
    eplb_rebalance_layers_per_chunk: Optional[int] = None
    expert_distribution_recorder_mode: Optional[
        Literal["stat", "stat_approx", "per_pass", "per_token"]
    ] = None
    expert_distribution_recorder_buffer_size: Optional[int] = None
    enable_expert_distribution_metrics: bool = False
    deepep_config: Optional[str] = None
    moe_dense_tp_size: Optional[int] = None

175
176
    # Double Sparsity
    enable_double_sparsity: bool = False
Vincent's avatar
Vincent committed
177
    ds_channel_config_path: Optional[str] = None
178
179
180
181
182
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

183
    # Optimization/debug options
Lianmin Zheng's avatar
Lianmin Zheng committed
184
    disable_radix_cache: bool = False
185
186
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
187
    disable_cuda_graph: bool = False
188
    disable_cuda_graph_padding: bool = False
189
    enable_profile_cuda_graph: bool = False
190
    enable_nccl_nvls: bool = False
191
    enable_tokenizer_batch_encode: bool = False
192
    disable_outlines_disk_cache: bool = False
193
    disable_custom_all_reduce: bool = False
194
    enable_mscclpp: bool = False
195
    disable_overlap_schedule: bool = False
196
    disable_overlap_cg_plan: bool = False
197
    enable_mixed_chunk: bool = False
Ke Bao's avatar
Ke Bao committed
198
    enable_dp_attention: bool = False
199
    enable_dp_lm_head: bool = False
200
    enable_two_batch_overlap: bool = False
201
    enable_torch_compile: bool = False
202
    torch_compile_max_bs: int = 32
203
    torchao_config: str = ""
204
    enable_nan_detection: bool = False
205
    enable_p2p_check: bool = False
206
    triton_attention_reduce_in_fp32: bool = False
207
    triton_attention_num_kv_splits: int = 8
208
    num_continuous_decode_steps: int = 1
209
    delete_ckpt_after_loading: bool = False
210
    enable_memory_saver: bool = False
211
    allow_auto_truncate: bool = False
212
    enable_custom_logit_processor: bool = False
213
    enable_hierarchical_cache: bool = False
214
    hicache_ratio: float = 2.0
Zhiqiang Xie's avatar
Zhiqiang Xie committed
215
216
    hicache_size: int = 0
    hicache_write_policy: str = "write_through_selective"
217
    flashinfer_mla_disable_ragged: bool = False
218
    disable_shared_experts_fusion: bool = False
219
    disable_chunked_prefix_cache: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
220
    disable_fast_image_processor: bool = False
221
    enable_return_hidden_states: bool = False
222
    warmups: Optional[str] = None
223
224
225
226
227

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False
228
    debug_tensor_dump_prefill_only: bool = False
229

Byron Hsu's avatar
Byron Hsu committed
230
231
    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
    disaggregation_mode: str = "null"
232
    disaggregation_transfer_backend: str = "mooncake"
233
    disaggregation_bootstrap_port: int = 8998
Byron Hsu's avatar
Byron Hsu committed
234
235
236
    disaggregation_decode_tp: Optional[int] = None
    disaggregation_decode_dp: Optional[int] = None
    disaggregation_prefill_pp: Optional[int] = 1
237
    disaggregation_ib_device: Optional[str] = None
238
    num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
239
    pdlb_url: Optional[str] = None
Byron Hsu's avatar
Byron Hsu committed
240

241
242
    # For model weight update
    custom_weight_loader: Optional[List[str]] = None
243
    weight_loader_disable_mmap: bool = False
244

Lianmin Zheng's avatar
Lianmin Zheng committed
245
    def __post_init__(self):
        """Resolve unset options and reconcile inter-dependent settings.

        Runs automatically after the dataclass-generated ``__init__``. It
        fills in fields that were left as ``None``, forces options that are
        required by (or incompatible with) the selected backends and
        parallelism modes, and validates a few combinations.

        Side effects:
            * Mutates many fields of ``self`` in place.
            * Sets the environment variables ``TRTLLM_ENABLE_PDL`` (only when
              Flashinfer MoE is enabled), ``SGLANG_ENABLE_TORCH_COMPILE`` and
              ``SGLANG_DISABLE_OUTLINES_DISK_CACHE``.
            * Emits log messages for every automatic adjustment.

        Raises:
            AssertionError: If an unsupported combination of options is found.
        """
        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.warning(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )
        if self.enable_flashinfer_moe:
            assert (
                self.quantization == "modelopt_fp4"
            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
            os.environ["TRTLLM_ENABLE_PDL"] = "1"
            self.disable_shared_experts_fusion = True
            logger.warning(
                "Flashinfer MoE is enabled. Shared expert fusion is disabled."
            )

        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.device is None:
            self.device = get_device()

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # May be None, in which case fixed fallbacks are used below.
        # The thresholds are written as `<GB> * 1024`, so the value is
        # presumably in MB -- confirm against get_device_memory_capacity.
        gpu_mem = get_device_memory_capacity(self.device)

        # Set mem fraction static
        if self.mem_fraction_static is None:
            if gpu_mem is not None:
                # GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
                # mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity.

                # We want mem_fraction_static to be as large as possible but still has enough room
                # for activations and cuda graph buffers. We use the following heuristic to
                # compute the needed size for activations and cuda graph buffers:
                # - The size of the activation depends on the chunked_prefill_size and model size.
                # - The size of cuda graph buffers depends on the cuda graph capture range and model size.
                # For GPUs with more memory, we use a larger chunked_prefill_size and
                # capture more cuda graphs, so they need to reserve more memory.
                parallel_size = self.tp_size * self.pp_size

                if gpu_mem < 20 * 1024:
                    # T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                    reserved_mem = (2.8 + parallel_size / 10) * 1024
                elif gpu_mem < 35 * 1024:
                    # A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                    reserved_mem = (2.8 + parallel_size / 10) * 1024
                elif gpu_mem < 90 * 1024:
                    # H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160)
                    reserved_mem = (9.5 + parallel_size / 2) * 1024
                elif gpu_mem < 100 * 1024:
                    # H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                    reserved_mem = (12 + parallel_size / 2) * 1024
                elif gpu_mem < 160 * 1024:
                    # H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                    reserved_mem = (12 + parallel_size / 2) * 1024
                else:
                    # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
                    reserved_mem = 32 * 1024

                if self.speculative_algorithm is not None:
                    # draft model and larger cuda graph buffers
                    reserved_mem += 2 * 1024
                if self.enable_dp_attention:
                    reserved_mem += 4 * 1024

                self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3)
            else:
                self.mem_fraction_static = 0.88

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None:
                if gpu_mem < 35 * 1024:  # A10, L40, 4090
                    self.chunked_prefill_size = 2048
                elif gpu_mem < 160 * 1024:  # H100, H200, A100, H20
                    self.chunked_prefill_size = 8192
                else:  # B200, MI300
                    self.chunked_prefill_size = 16384
            else:
                self.chunked_prefill_size = 4096
        # Checked against the page_size as given by the user; the backend
        # specific page_size overrides further down happen after this check.
        assert self.chunked_prefill_size % self.page_size == 0

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
            if gpu_mem is not None and gpu_mem < 35 * 1024:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80

        assert self.moe_dense_tp_size in {
            1,
            None,
        }, "moe_dense_tp_size only support 1 and None currently"

        # Attention backends that mandate a fixed page size
        if self.attention_backend == "flashmla":
            logger.warning(
                "FlashMLA only supports a page_size of 64, change page_size to 64."
            )
            self.page_size = 64

        if self.attention_backend == "cutlass_mla":
            logger.warning(
                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

        # Set kernel backends for hpu device
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        # Set kernel backends
        if self.device == "cpu":
            if self.attention_backend is None:
                self.attention_backend = "intel_amx"
            self.sampling_backend = "pytorch"

        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True

        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
            )

        if self.enable_dp_lm_head:
            # The message names the flag the user actually set
            # (enable_dp_lm_head), which requires DP attention.
            assert (
                self.enable_dp_attention
            ), "Please enable dp attention when setting enable_dp_lm_head. "

        # DeepEP MoE
        if self.enable_deepep_moe:
            if self.deepep_mode == "auto":
                assert (
                    not self.enable_dp_attention
                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
            if self.deepep_mode == "normal":
                logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                self.disable_cuda_graph = True
            self.ep_size = self.tp_size
            logger.warning(
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        if self.pp_size > 1:
            self.disable_overlap_schedule = True
            logger.warning(
                "Pipeline parallelism is incompatible with overlap schedule."
            )

        # Expert load balancing and expert distribution recording
        if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
            self.expert_distribution_recorder_mode = "stat"
            logger.info(
                "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
            )

        # NOTE(review): init_expert_location defaults to "trivial", so the
        # `is not None` test is true unless None is passed explicitly; as a
        # result ep_dispatch_algorithm defaults to "static" -- confirm intent.
        if (self.enable_eplb or (self.init_expert_location is not None)) and (
            self.ep_dispatch_algorithm is None
        ):
            self.ep_dispatch_algorithm = "static"
            logger.info(
                "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
            )

        if self.enable_expert_distribution_metrics and (
            self.expert_distribution_recorder_mode is None
        ):
            self.expert_distribution_recorder_mode = "stat"

        if self.expert_distribution_recorder_buffer_size is None:
            if (x := self.eplb_rebalance_num_iterations) is not None:
                self.expert_distribution_recorder_buffer_size = x
            elif self.expert_distribution_recorder_mode is not None:
                self.expert_distribution_recorder_buffer_size = 1000

        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
            if self.max_running_requests is None:
                self.max_running_requests = 48
            self.disable_overlap_schedule = True
            logger.warning(
                "Overlap scheduler is disabled because of using "
                "eagle speculative decoding."
            )
            if self.enable_mixed_chunk:
                self.enable_mixed_chunk = False
                logger.warning(
                    "Mixed chunked prefill is disabled because of using "
                    "eagle speculative decoding."
                )

            model_arch = get_model_arch(self)

            # Auto set draft_model_path DeepSeek-V3/R1
            if model_arch == "DeepseekV3ForCausalLM":
                if self.speculative_draft_model_path is None:
                    self.speculative_draft_model_path = self.model_path
                else:
                    logger.warning(
                        "DeepSeek MTP does not require setting speculative_draft_model_path."
                    )

            # Auto choose parameters
            if self.speculative_num_steps is None:
                # The three parameters are chosen together, so the other two
                # must not have been set when num_steps is left unset.
                assert (
                    self.speculative_eagle_topk is None
                    and self.speculative_num_draft_tokens is None
                )
                (
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
                ) = auto_choose_speculative_params(self)

            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
            ):
                logger.warning(
                    "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
                )
                self.speculative_num_draft_tokens = self.speculative_num_steps + 1

            # The token generated from the verify step is counted.
            # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

        if is_remote_url(self.model_path):
            self.load_format = "remote"

        # AMD-specific Triton attention KV splits default number
        if is_hip():
            self.triton_attention_num_kv_splits = 16

        # PD disaggregation
        if self.disaggregation_mode == "decode":
            assert (
                self.disaggregation_decode_tp is None
            ), "Cannot set --disaggregation-decode-tp for the decode engine."
            assert (
                self.disaggregation_decode_dp is None
            ), "Cannot set --disaggregation-decode-dp for the decode engine."

            self.disable_radix_cache = True
            logger.warning("KV cache is forced as chunk cache for decode server")
        elif self.disaggregation_mode == "prefill":
            # Decode-side sizes default to this (prefill) server's own sizes.
            if self.disaggregation_decode_tp is None:
                self.disaggregation_decode_tp = self.tp_size
            if self.disaggregation_decode_dp is None:
                self.disaggregation_decode_dp = self.dp_size

            self.disaggregation_prefill_pp = self.pp_size
            self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)

            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for prefill server")

        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )

        if self.custom_weight_loader is None:
            self.custom_weight_loader = []

Byron Hsu's avatar
Byron Hsu committed
550
551
552
553
554
555
556
557
    def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
        larger_tp = max(decode_tp, prefill_tp)
        smaller_tp = min(decode_tp, prefill_tp)
        assert larger_tp % smaller_tp == 0, (
            "Different tp size is supported only when one tp is multiple of the other. "
            f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
558
559
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
560
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
561
562
        parser.add_argument(
            "--model-path",
563
            "--model",
Lianmin Zheng's avatar
Lianmin Zheng committed
564
565
566
567
568
569
570
571
572
573
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
574
        parser.add_argument(
575
576
577
578
            "--host",
            type=str,
            default=ServerArgs.host,
            help="The host of the HTTP server.",
Yuanhan Zhang's avatar
Yuanhan Zhang committed
579
580
        )
        parser.add_argument(
581
582
583
584
            "--port",
            type=int,
            default=ServerArgs.port,
            help="The port of the HTTP server.",
Yuanhan Zhang's avatar
Yuanhan Zhang committed
585
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
586
587
588
589
590
591
592
593
594
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
595
596
597
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
598
            help="If set, skip init tokenizer and pass input_ids in generate request.",
599
        )
600
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
601
602
603
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
604
605
606
607
608
609
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
610
                "sharded_state",
611
612
                "gguf",
                "bitsandbytes",
613
                "layered",
614
                "remote",
615
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
616
617
618
619
620
621
622
623
624
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
625
            "which is mainly for profiling."
626
627
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
628
629
630
631
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
632
        )
633
634
635
636
637
638
639
        parser.add_argument(
            "--model-loader-extra-config",
            type=str,
            help="Extra config for model loader. "
            "This will be passed to the model loader corresponding to the chosen load_format.",
            default=ServerArgs.model_loader_extra_config,
        )
640
641
642
643
644
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
645
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
646
            "--dtype",
Cody Yu's avatar
Cody Yu committed
647
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
648
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
649
650
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
651
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
652
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
653
654
655
656
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
657
658
            '* "float32" for FP32 precision.',
        )
659
660
661
662
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
bjmsong's avatar
bjmsong committed
663
664
665
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
666
667
668
669
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
670
671
672
673
674
675
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
676
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
677
                "bitsandbytes",
678
                "gguf",
679
                "modelopt",
680
                "modelopt_fp4",
681
                "w8a8_int8",
HandH1998's avatar
HandH1998 committed
682
                "w8a8_fp8",
AniZpZ's avatar
AniZpZ committed
683
                "moe_wna16",
HandH1998's avatar
HandH1998 committed
684
                "qoq",
Ying Sheng's avatar
Ying Sheng committed
685
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
686
687
            help="The quantization method.",
        )
688
689
690
691
692
693
694
695
696
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
697
698
699
700
701
702
703
704
705
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
706
            default=ServerArgs.device,
707
            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
708
        )
709
710
711
712
713
714
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
715
716
717
718
719
720
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
721
722
723
724
725
726
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
        )
727
728
729
730
731
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
732
733
734
735
736
737
        parser.add_argument(
            "--enable-multimodal",
            default=ServerArgs.enable_multimodal,
            action="store_true",
            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
        )
738
739
740
741
742
743
744
745
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
746
747
748
749
750
751
752
753
754
755
756
757
        parser.add_argument(
            "--impl",
            type=str,
            default=ServerArgs.impl,
            help="Which implementation of the model to use.\n\n"
            '* "auto" will try to use the SGLang implementation if it exists '
            "and fall back to the Transformers implementation if no SGLang "
            "implementation is available.\n"
            '* "sglang" will use the SGLang model implementation.\n'
            '* "transformers" will use the Transformers model '
            "implementation.\n",
        )
758

759
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
760
761
762
763
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
764
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
765
        )
766
767
768
769
770
771
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
772
773
774
775
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
776
777
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
778
        )
779
780
781
782
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
783
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
784
785
786
787
788
789
790
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
791
        parser.add_argument(
792
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
793
            type=str,
794
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
795
            choices=["lpm", "random", "fcfs", "dfs-weight"],
796
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
797
        )
798
799
800
801
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
802
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
803
        )
804
805
806
807
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
808
            help="How many GBs of RAM to reserve for CPU offloading.",
809
        )
810
811
812
813
814
815
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
tarinkk's avatar
tarinkk committed
816
817
818
819
820
821
822
823
824
825
826
827
        parser.add_argument(
            "--hybrid-kvcache-ratio",
            nargs="?",
            const=0.5,
            type=float,
            default=ServerArgs.hybrid_kvcache_ratio,
            help=(
                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
                "(0.0 = pure uniform: swa_size / full_size = 1)"
                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
            ),
        )
828

829
        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
830
        parser.add_argument(
831
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
832
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
833
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
834
            default=ServerArgs.tp_size,
835
            help="The tensor parallelism size.",
836
        )
837
838
839
840
841
842
843
844
845
846
847
848
849
        parser.add_argument(
            "--pipeline-parallel-size",
            "--pp-size",
            type=int,
            default=ServerArgs.pp_size,
            help="The pipeline parallelism size.",
        )
        parser.add_argument(
            "--max-micro-batch-size",
            type=int,
            default=ServerArgs.max_micro_batch_size,
            help="The maximum micro batch size in pipeline parallelism.",
        )
850
851
852
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
853
            default=ServerArgs.stream_interval,
854
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
855
        )
856
857
858
859
860
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
861
862
863
864
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
865
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
866
        )
867
868
869
870
871
872
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
873
874
875
876
877
878
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
879
880
881
882
883
884
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
885
886
887
888
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
889
            help="Model download directory for huggingface.",
890
        )
891
892
893
894
895
896
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
897
898
899
900
901
902
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )
903
904
905
906
907
        parser.add_argument(
            "--sleep-on-idle",
            action="store_true",
            help="Reduce CPU usage when sglang is idle.",
        )
908
909

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
910
911
912
913
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
914
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
915
        )
916
        parser.add_argument(
917
918
919
920
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
921
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
922
        parser.add_argument(
923
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
924
            action="store_true",
925
926
927
928
929
930
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
931
932
933
934
935
936
937
938
            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
            choices=[0, 1, 2, 3],
        )
        parser.add_argument(
            "--crash-dump-folder",
            type=str,
            default=ServerArgs.crash_dump_folder,
            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
Lianmin Zheng's avatar
Lianmin Zheng committed
939
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
940
941
942
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
943
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
944
        )
945
946
947
948
949
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
        parser.add_argument(
            "--bucket-time-to-first-token",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_time_to_first_token,
            help="The buckets of time to first token, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-inter-token-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_inter_token_latency,
            help="The buckets of inter-token latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-e2e-request-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_e2e_request_latency,
            help="The buckets of end-to-end request latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--collect-tokens-histogram",
            action="store_true",
            default=ServerArgs.collect_tokens_histogram,
            help="Collect prompt/generation tokens histogram.",
        )
977
978
979
980
981
982
        parser.add_argument(
            "--kv-events-config",
            type=str,
            default=None,
            help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
        )
983
984
985
986
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
987
            help="The log interval of decode batch.",
988
        )
989
990
991
992
993
994
        parser.add_argument(
            "--enable-request-time-stats-logging",
            action="store_true",
            default=ServerArgs.enable_request_time_stats_logging,
            help="Enable per request time stats logging",
        )
995

996
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
997
998
999
1000
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
1001
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1002
        )
1003
        parser.add_argument(
1004
            "--file-storage-path",
1005
            type=str,
1006
            default=ServerArgs.file_storage_path,
1007
1008
            help="The path of the file storage in backend.",
        )
1009
1010
1011
1012
1013
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Xihuai Wang's avatar
Xihuai Wang committed
1014
1015
1016
1017
1018
1019
1020
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )
1021
1022
1023
1024
1025
1026
1027
        parser.add_argument(
            "--tool-call-parser",
            type=str,
            choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1028

1029
1030
        # Data parallelism
        parser.add_argument(
1031
            "--data-parallel-size",
1032
1033
1034
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
1035
            help="The data parallelism size.",
1036
1037
1038
1039
1040
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
1041
            help="The load balancing strategy for data parallelism.",
1042
1043
1044
1045
1046
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
1047

1048
        # Multi-node distributed serving
1049
        parser.add_argument(
1050
            "--dist-init-addr",
1051
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
1052
            type=str,
1053
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
1054
1055
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
1056
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
1057
        )
1058
1059
1060
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
1061

Lianmin Zheng's avatar
Lianmin Zheng committed
1062
1063
1064
1065
1066
1067
1068
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
1069
1070
1071
1072
1073
        parser.add_argument(
            "--preferred-sampling-params",
            type=str,
            help="json-formatted sampling settings that will be returned in /get_model_info",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1074

1075
1076
1077
1078
1079
1080
1081
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
1082
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
1083
1084
1085
1086
1087
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
1088
1089
1090
1091
1092
1093
1094
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
1095
1096
1097
        )

        # Kernel backend
1098
1099
1100
        parser.add_argument(
            "--attention-backend",
            type=str,
1101
            choices=[
1102
                "aiter",
1103
                "cutlass_mla",
1104
                "fa3",
1105
                "flashinfer",
1106
                "flashmla",
1107
                "intel_amx",
1108
1109
                "torch_native",
                "triton",
1110
            ],
1111
1112
1113
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
1114
1115
1116
1117
1118
1119
1120
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
1121
1122
1123
        parser.add_argument(
            "--grammar-backend",
            type=str,
1124
            choices=["xgrammar", "outlines", "llguidance", "none"],
1125
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
1126
            help="Choose the backend for grammar-guided decoding.",
1127
        )
1128

1129
1130
1131
1132
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
James Liu's avatar
James Liu committed
1133
            choices=["EAGLE", "EAGLE3", "NEXTN"],
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
1150
            help="The number of tokens sampled from the draft model in eagle2 each step.",
1151
1152
            default=ServerArgs.speculative_eagle_topk,
        )
1153
1154
1155
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
1156
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
1157
1158
            default=ServerArgs.speculative_num_draft_tokens,
        )
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
1171
1172
1173
1174
1175
1176
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
        parser.add_argument(
            "--mm-attention-backend",
            type=str,
            choices=["sdpa", "fa3", "triton_attn"],
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )

        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
1198
1199
1200
1201
1202
        parser.add_argument(
            "--enable-flashinfer-moe",
            action="store_true",
            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
        )
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="Enabling DeepEP MoE implementation for EP MoE.",
        )
        parser.add_argument(
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
            default="auto",
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
        parser.add_argument(
            "--ep-num-redundant-experts",
            type=int,
            default=ServerArgs.ep_num_redundant_experts,
            help="Allocate this number of redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--ep-dispatch-algorithm",
            type=str,
            default=ServerArgs.ep_dispatch_algorithm,
            help="The algorithm to choose ranks for redundant experts in expert parallel.",
        )
        parser.add_argument(
            "--init-expert-location",
            type=str,
            default=ServerArgs.init_expert_location,
            help="Initial location of EP experts.",
        )
        parser.add_argument(
            "--enable-eplb",
            action="store_true",
            help="Enable EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-algorithm",
            type=str,
            default=ServerArgs.eplb_algorithm,
            help="Chosen EPLB algorithm",
        )
        parser.add_argument(
            "--eplb-rebalance-num-iterations",
            type=int,
            default=ServerArgs.eplb_rebalance_num_iterations,
            help="Number of iterations to automatically trigger a EPLB re-balance.",
        )
        parser.add_argument(
            "--eplb-rebalance-layers-per-chunk",
            type=int,
            default=ServerArgs.eplb_rebalance_layers_per_chunk,
            help="Number of layers to rebalance per forward pass.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-mode",
            type=str,
            default=ServerArgs.expert_distribution_recorder_mode,
            help="Mode of expert distribution recorder.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-buffer-size",
            type=int,
            default=ServerArgs.expert_distribution_recorder_buffer_size,
            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
        )
        parser.add_argument(
            "--enable-expert-distribution-metrics",
            action="store_true",
            help="Enable logging metrics for expert balancedness",
        )
        parser.add_argument(
            "--deepep-config",
            type=str,
            default=ServerArgs.deepep_config,
            help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
        )
        parser.add_argument(
            "--moe-dense-tp-size",
            type=int,
            default=ServerArgs.moe_dense_tp_size,
            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
        )
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

1323
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
1324
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
1325
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
1326
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
1327
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1328
        )
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
1341
1342
1343
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
1344
            help="Disable cuda graph.",
1345
        )
1346
        parser.add_argument(
1347
1348
            "--disable-cuda-graph-padding",
            action="store_true",
1349
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
1350
        )
1351
1352
1353
1354
1355
        parser.add_argument(
            "--enable-profile-cuda-graph",
            action="store_true",
            help="Enable profiling of cuda graph capture.",
        )
1356
1357
1358
1359
1360
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
1361
1362
1363
1364
1365
        parser.add_argument(
            "--enable-tokenizer-batch-encode",
            action="store_true",
            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
        )
1366
        parser.add_argument(
1367
            "--disable-outlines-disk-cache",
1368
            action="store_true",
1369
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
1370
        )
1371
1372
1373
1374
1375
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
1376
1377
1378
1379
1380
        parser.add_argument(
            "--enable-mscclpp",
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1381
        parser.add_argument(
1382
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
1383
            action="store_true",
1384
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1385
        )
1386
1387
1388
1389
1390
        parser.add_argument(
            "--disable-overlap-cg-plan",
            action="store_true",
            help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
        )
1391
1392
1393
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
1394
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
1395
        )
Ke Bao's avatar
Ke Bao committed
1396
1397
1398
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
1399
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
Ke Bao's avatar
Ke Bao committed
1400
        )
1401
1402
1403
1404
1405
        parser.add_argument(
            "--enable-dp-lm-head",
            action="store_true",
            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
        )
1406
1407
1408
1409
1410
        parser.add_argument(
            "--enable-two-batch-overlap",
            action="store_true",
            help="Enabling two micro batches to overlap.",
        )
1411
1412
1413
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
1414
1415
            help="Optimize the model with torch.compile. Experimental feature.",
        )
1416
        parser.add_argument(
1417
            "--torch-compile-max-bs",
1418
            type=int,
1419
            default=ServerArgs.torch_compile_max_bs,
1420
1421
            help="Set the maximum batch size when using torch compile.",
        )
1422
1423
1424
1425
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
1426
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
1427
        )
1428
1429
1430
1431
1432
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1433
        parser.add_argument(
1434
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
1435
            action="store_true",
1436
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1437
        )
1438
        parser.add_argument(
1439
            "--triton-attention-reduce-in-fp32",
1440
            action="store_true",
1441
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
1442
            "This only affects Triton attention kernels.",
1443
        )
1444
1445
1446
1447
1448
1449
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
1450
1451
1452
1453
1454
1455
1456
1457
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
1458
1459
1460
1461
1462
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
1463
1464
1465
1466
1467
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
1468
1469
1470
1471
1472
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
1473
1474
1475
1476
1477
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
1478
1479
1480
1481
1482
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
1483
1484
1485
1486
1487
1488
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
Zhiqiang Xie's avatar
Zhiqiang Xie committed
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
        parser.add_argument(
            "--hicache-size",
            type=int,
            default=ServerArgs.hicache_size,
            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
        )
        parser.add_argument(
            "--hicache-write-policy",
            type=str,
            choices=["write_back", "write_through", "write_through_selective"],
            default=ServerArgs.hicache_write_policy,
            help="The write policy of hierarchical cache.",
        )
1502
        parser.add_argument(
1503
            "--flashinfer-mla-disable-ragged",
1504
            action="store_true",
1505
            help="Not using ragged prefill wrapper when running flashinfer mla",
1506
        )
1507
        parser.add_argument(
1508
1509
1510
            "--disable-shared-experts-fusion",
            action="store_true",
            help="Disable shared experts fusion optimization for deepseek v3/r1.",
1511
        )
1512
1513
1514
1515
1516
        parser.add_argument(
            "--disable-chunked-prefix-cache",
            action="store_true",
            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1517
1518
1519
1520
1521
        parser.add_argument(
            "--disable-fast-image-processor",
            action="store_true",
            help="Adopt base image processor instead of fast image processor.",
        )
1522
1523
1524
1525
1526
        parser.add_argument(
            "--enable-return-hidden-states",
            action="store_true",
            help="Enable returning hidden states with responses.",
        )
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )
1554
1555
1556
1557
1558
        parser.add_argument(
            "--debug-tensor-dump-prefill-only",
            action="store_true",
            help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
        )
1559

Byron Hsu's avatar
Byron Hsu committed
1560
1561
1562
1563
1564
1565
1566
1567
        # Disaggregation
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
1568
1569
1570
1571
        parser.add_argument(
            "--disaggregation-transfer-backend",
            type=str,
            default=ServerArgs.disaggregation_transfer_backend,
1572
            choices=["mooncake", "nixl"],
1573
1574
            help="The backend for disaggregation transfer. Default is mooncake.",
        )
1575
1576
1577
1578
1579
1580
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )
Byron Hsu's avatar
Byron Hsu committed
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
        parser.add_argument(
            "--disaggregation-decode-tp",
            type=int,
            default=ServerArgs.disaggregation_decode_tp,
            help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
        )
        parser.add_argument(
            "--disaggregation-decode-dp",
            type=int,
            default=ServerArgs.disaggregation_decode_dp,
            help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
        )
        parser.add_argument(
            "--disaggregation-prefill-pp",
            type=int,
            default=ServerArgs.disaggregation_prefill_pp,
            help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
        )
1599
1600
1601
1602
        parser.add_argument(
            "--disaggregation-ib-device",
            type=str,
            default=ServerArgs.disaggregation_ib_device,
1603
1604
1605
            help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
            "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
            "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
1606
        )
1607
1608
1609
1610
1611
1612
        parser.add_argument(
            "--num-reserved-decode-tokens",
            type=int,
            default=ServerArgs.num_reserved_decode_tokens,
            help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
        )
1613
1614
1615
1616
1617
1618
        parser.add_argument(
            "--pdlb-url",
            type=str,
            default=None,
            help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
        )
1619
1620
1621
1622
1623
1624
1625
        parser.add_argument(
            "--custom-weight-loader",
            type=str,
            nargs="*",
            default=None,
            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
        )
1626
1627
1628
1629
1630
        parser.add_argument(
            "--weight-loader-disable-mmap",
            action="store_true",
            help="Disable mmap while loading weight using safetensors.",
        )
Byron Hsu's avatar
Byron Hsu committed
1631

Lianmin Zheng's avatar
Lianmin Zheng committed
1632
1633
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        The long-form parallelism options are first copied onto ``args`` under
        their short attribute names, so ``args`` is mutated in place. Every
        dataclass field of ``cls`` is then read from ``args`` by name and
        passed to the constructor as a keyword argument.

        Args:
            args: Namespace holding one attribute per dataclass field, plus
                the four ``*_parallel_size`` attributes.

        Returns:
            A new instance of ``cls``.
        """
        # (short field name, long CLI attribute name), applied in this order.
        parallel_aliases = (
            ("tp_size", "tensor_parallel_size"),
            ("pp_size", "pipeline_parallel_size"),
            ("dp_size", "data_parallel_size"),
            ("ep_size", "expert_parallel_size"),
        )
        for short_name, cli_name in parallel_aliases:
            setattr(args, short_name, getattr(args, cli_name))

        field_values = {}
        for field in dataclasses.fields(cls):
            field_values[field.name] = getattr(args, field.name)
        return cls(**field_values)

    def url(self):
        """Return the ``http://host:port`` base URL built from this config.

        A host that ``is_valid_ipv6_address`` accepts is wrapped in square
        brackets; any other host is used as-is.
        """
        if not is_valid_ipv6_address(self.host):
            return f"http://{self.host}:{self.port}"
        # Brackets keep the address' own colons apart from the port separator.
        return f"http://[{self.host}]:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
1646

1647
1648
    def check_server_args(self):
        """Validate argument combinations and normalize ``lora_paths``.

        Checks, in order:
          * ``tp_size * pp_size`` is a multiple of ``nnodes``;
          * with ``pp_size > 1``: overlap schedule is disabled, no speculative
            algorithm is set and mixed chunk is off;
          * ``dp_size > 1`` across several nodes requires DP attention;
          * ``max_loras_per_batch`` is positive and LoRA paths are only given
            together with a disabled radix cache;
          * ``base_gpu_id`` is non-negative and ``gpu_id_step`` is at least 1.

        Afterwards, a list-valued ``lora_paths`` is replaced by a dict mapping
        adapter name to path. An entry ``name=path`` is split at the first
        ``=``; an entry without ``=`` is used as both name and path.

        Raises:
            AssertionError: If any of the checks above fails.
        """
        model_parallel_size = self.tp_size * self.pp_size
        assert (
            model_parallel_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"

        # FIXME pp constraints
        if self.pp_size > 1:
            pp_compatible = (
                self.disable_overlap_schedule
                and self.speculative_algorithm is None
                and not self.enable_mixed_chunk
            )
            assert (
                pp_compatible
            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

        multi_node_plain_dp = (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        )
        assert (
            not multi_node_plain_dp
        ), "multi-node data parallel is not supported unless dp attention!"

        lora_supported = (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_radix_cache)
        )
        assert lora_supported, "compatibility of lora and radix attention is in progress"

        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

        if isinstance(self.lora_paths, list):
            raw_entries = self.lora_paths
            self.lora_paths = {}
            for entry in raw_entries:
                # Without "=", partition returns (entry, "", ""), so `name`
                # is the whole entry and it maps to itself.
                name, separator, path = entry.partition("=")
                self.lora_paths[name] = path if separator else entry

Lianmin Zheng's avatar
Lianmin Zheng committed
1681

Lianmin Zheng's avatar
Lianmin Zheng committed
1682
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Parse command line arguments into a `ServerArgs` instance.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    cli_parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli_parser)
    parsed_namespace = cli_parser.parse_args(argv)
    return ServerArgs.from_cli_args(parsed_namespace)


1700
1701
1702
# Offset added to the server port to derive the default distributed-init port
# when DP attention runs on a single node and no dist_init_addr is given.
ZMQ_TCP_PORT_DELTA = 233


Lianmin Zheng's avatar
Lianmin Zheng committed
1703
1704
@dataclasses.dataclass
class PortArgs:
    """Endpoints the server processes use to reach each other.

    The ``*_ipc_name`` fields hold ZMQ endpoint strings (``ipc://...`` or
    ``tcp://host:port``); ``nccl_port`` is a plain port number.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    # The ipc filename for Scheduler to send metrics
    metrics_ipc_name: str

    @staticmethod
    def _resolve_dist_init_addr(server_args):
        """Return the ``(host, port)`` pair used as the DP-attention TCP base.

        On a single node without ``dist_init_addr`` this falls back to
        ``127.0.0.1`` and ``server_args.port + ZMQ_TCP_PORT_DELTA``. Otherwise
        ``dist_init_addr`` is parsed, either as a bracketed IPv6 address or as
        ``host:port``.

        Raises:
            AssertionError: If ``dist_init_addr`` is required but missing, or
                does not split into exactly a host and a port.
        """
        addr = server_args.dist_init_addr
        if server_args.nnodes == 1 and addr is None:
            return ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)

        # Multi-node runs have no default address. Fail with the same hint as
        # a malformed address rather than an AttributeError on None.
        assert (
            addr is not None
        ), "please provide --dist-init-addr as host:port of head node"

        if addr.startswith("["):  # ipv6 address
            port_num, host = configure_ipv6(addr)
            dist_init_addr = (host, str(port_num))
        else:
            dist_init_addr = addr.split(":")

        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"
        return dist_init_addr

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate a fresh set of endpoints for one server instance.

        Args:
            server_args: Server configuration; ``port``, ``enable_dp_attention``,
                ``nnodes`` and ``dist_init_addr`` are read.
            dp_rank: Data-parallel rank, used only with DP attention to pick
                the scheduler input port. ``None`` selects the port shared by
                the TokenizerManager and the DataParallelController.

        Returns:
            A ``PortArgs`` using temp-file backed ``ipc://`` endpoints without
            DP attention, or ``tcp://`` endpoints on consecutive ports above
            the distributed-init port with DP attention.

        Raises:
            AssertionError: With DP attention, if ``dist_init_addr`` is
                required but missing or is not of the form ``host:port``.
        """
        # Start from a random offset above the server port and probe until a
        # free port is found, walking up below 60000 and down otherwise.
        port = server_args.port + random.randint(100, 1000)
        while not is_port_available(port):
            if port < 60000:
                port += 42
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                nccl_port=port,
                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        dist_init_host, dist_init_port = PortArgs._resolve_dist_init_addr(server_args)

        # Port layout relative to port_base:
        #   +0 tokenizer, +1 detokenizer, +2 rpc, +3 metrics,
        #   +4 scheduler input without a dp_rank, +5 + dp_rank otherwise.
        port_base = int(dist_init_port) + 1
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 4
        else:
            scheduler_input_port = port_base + 4 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            nccl_port=port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
            metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
        )
1772

1773
1774
1775
1776
1777
1778
1779
1780
1781
1782

class LoRAPathAction(argparse.Action):
    """Argparse action that collects LoRA path values into a name -> path dict.

    Each value is either ``name=path`` (split at the first ``=``) or a bare
    path, which is then used as both the name and the path.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        # Store the dict before filling it, so the destination is reset to an
        # empty mapping every time the option is seen.
        setattr(namespace, self.dest, adapters)
        for entry in values:
            # Without "=", partition returns (entry, "", ""), so `name` is the
            # whole entry and it maps to itself.
            name, separator, path = entry.partition("=")
            adapters[name] = path if separator else entry
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792


class DeprecatedAction(argparse.Action):
    """Argparse action for removed options: using the option raises an error.

    The option consumes no values by default (``nargs=0``), and the text
    passed as ``help`` doubles as the error message.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Raised as soon as the option appears on the command line.
        raise ValueError(self.help)
1793
1794


1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
def get_model_arch(args: ServerArgs):
    """Return the first entry of ``architectures`` in the model's config.

    The config is loaded with ``get_config`` from ``args.model_path``, with
    ``args.json_model_override_args`` decoded from JSON and passed as the
    model override arguments.
    """
    config_options = dict(
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
        model_override_args=json.loads(args.json_model_override_args),
    )
    architectures = get_config(args.model_path, **config_options).architectures
    return architectures[0]


1805
def auto_choose_speculative_params(self: ServerArgs):
    """
    Pick default speculative decoding parameters from the model architecture.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py

    Returns:
        A 3-tuple of ints chosen by the first architecture name in the model
        config. NOTE(review): presumably (num steps, eagle top-k, num draft
        tokens) -- confirm against the caller.
    """
    hf_config = get_config(
        self.model_path,
        trust_remote_code=self.trust_remote_code,
        revision=self.revision,
        model_override_args=json.loads(self.json_model_override_args),
    )
    arch = hf_config.architectures[0]

    # Checked in order; the first group containing `arch` wins.
    defaults_by_arch = (
        # The default value for llama
        (("LlamaForCausalLM",), (5, 4, 8)),
        # The default value for deepseek
        (("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"), (3, 1, 4)),
        (("Grok1ForCausalLM", "Grok1VForCausalLM"), (5, 4, 8)),
    )
    for arch_names, params in defaults_by_arch:
        if arch in arch_names:
            return params

    # The default value for all other models
    return (5, 4, 8)