server_args.py 64.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import json
19
import logging
20
import os
21
import random
22
import tempfile
23
from typing import List, Literal, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
Xihuai Wang's avatar
Xihuai Wang committed
26
from sglang.srt.reasoning_parser import ReasoningParser
27
from sglang.srt.utils import (
Vincent's avatar
Vincent committed
28
    configure_ipv6,
29
    get_device,
Lianmin Zheng's avatar
Lianmin Zheng committed
30
    get_device_memory_capacity,
31
    is_cuda,
32
    is_flashinfer_available,
HAI's avatar
HAI committed
33
    is_hip,
34
    is_port_available,
35
    is_remote_url,
36
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
37
    nullable_str,
38
)
39

40
41
# Module-scoped logger, named after this module, shared by ServerArgs.
logger = logging.getLogger(name=__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
42
43
44

@dataclasses.dataclass
class ServerArgs:
    """The arguments of the server.

    Only ``model_path`` is required. Several fields default to ``None`` and
    are filled in by ``__post_init__`` from other arguments and from the
    detected device (marked "resolved in __post_init__" below).
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None  # resolved in __post_init__ (model_path)
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    trust_remote_code: bool = False
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    quantization_param_path: Optional[str] = None
    context_length: Optional[int] = None
    device: Optional[str] = None  # resolved in __post_init__ (get_device())
    served_model_name: Optional[str] = None  # resolved in __post_init__ (model_path)
    chat_template: Optional[str] = None
    completion_template: Optional[str] = None
    is_embedding: bool = False
    enable_multimodal: Optional[bool] = None
    revision: Optional[str] = None
    impl: str = "auto"

    # Host and port for the HTTP server
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None  # resolved in __post_init__
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: Optional[int] = None  # resolved in __post_init__
    max_prefill_tokens: int = 16384
    schedule_policy: str = "fcfs"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    page_size: int = 1

    # Other runtime options
    tp_size: int = 1
    pp_size: int = 1
    max_micro_batch_size: Optional[int] = None
    stream_interval: int = 1
    stream_output: bool = False
    random_seed: Optional[int] = None  # resolved in __post_init__ (random)
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
    download_dir: Optional[str] = None
    base_gpu_id: int = 0
    gpu_id_step: int = 1

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    log_requests_level: int = 0
    show_time_cost: bool = False
    enable_metrics: bool = False
    bucket_time_to_first_token: Optional[List[float]] = None
    bucket_e2e_request_latency: Optional[List[float]] = None
    bucket_inter_token_latency: Optional[List[float]] = None
    collect_tokens_histogram: bool = False
    decode_log_interval: int = 40
    enable_request_time_stats_logging: bool = False
    kv_events_config: Optional[str] = None

    # API related
    api_key: Optional[str] = None
    file_storage_path: str = "sglang_storage"
    enable_cache_report: bool = False
    reasoning_parser: Optional[str] = None

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Expert parallelism
    ep_size: int = 1

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"
    preferred_sampling_params: Optional[str] = None

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None  # resolved in __post_init__
    grammar_backend: Optional[str] = None  # resolved in __post_init__ ("xgrammar")

    # Speculative decoding
    speculative_algorithm: Optional[str] = None
    speculative_draft_model_path: Optional[str] = None
    speculative_num_steps: Optional[int] = None
    speculative_eagle_topk: Optional[int] = None
    speculative_num_draft_tokens: Optional[int] = None
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
    speculative_token_map: Optional[str] = None

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    enable_tokenizer_batch_encode: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_dp_lm_head: bool = False
    enable_two_batch_overlap: bool = False
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
    ep_num_redundant_experts: int = 0
    ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
    init_expert_location: str = "trivial"
    enable_eplb: bool = False
    eplb_algorithm: str = "auto"
    eplb_rebalance_num_iterations: int = 1000
    expert_distribution_recorder_mode: Optional[
        Literal["stat", "per_pass", "per_token"]
    ] = None
    expert_distribution_recorder_buffer_size: Optional[int] = None
    enable_expert_distribution_metrics: bool = False
    deepep_config: Optional[str] = None
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    enable_custom_logit_processor: bool = False
    tool_call_parser: Optional[str] = None
    enable_hierarchical_cache: bool = False
    hicache_ratio: float = 2.0
    hicache_size: int = 0
    hicache_write_policy: str = "write_through_selective"
    flashinfer_mla_disable_ragged: bool = False
    warmups: Optional[str] = None
    moe_dense_tp_size: Optional[int] = None
    disable_shared_experts_fusion: bool = False
    disable_chunked_prefix_cache: bool = False
    disable_fast_image_processor: bool = False
    mm_attention_backend: Optional[str] = None

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False

    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
    disaggregation_mode: str = "null"
    disaggregation_bootstrap_port: int = 8998
    disaggregation_transfer_backend: str = "mooncake"
    disaggregation_ib_device: Optional[str] = None
    pdlb_url: Optional[str] = None

    def __post_init__(self):
        """Resolve derived defaults and validate argument combinations.

        Runs automatically after the dataclass ``__init__``. Fills in fields
        left as ``None`` from other arguments and from the detected device
        memory capacity. These include the tokenizer path, device, served
        model name, random seed, memory fraction, chunked prefill size, CUDA
        graph batch size, kernel backends and speculative-decoding parameters.
        It then adjusts options that are incompatible with each other, such
        as disabling CUDA graph or the overlap scheduler, and logs each
        adjustment.

        Side effects:
            Sets the ``SGLANG_ENABLE_TORCH_COMPILE`` and
            ``SGLANG_DISABLE_OUTLINES_DISK_CACHE`` environment variables.

        Raises:
            AssertionError: if an argument combination is unsupported, for
                example ``chunked_prefill_size`` not divisible by
                ``page_size``, ``enable_dp_attention`` with ``dp_size <= 1``,
                or ``enable_dp_lm_head`` without ``enable_dp_attention``.
        """
        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.warning(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.device is None:
            self.device = get_device()

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # May be None (every use below checks for it). The thresholds below
        # suggest the unit is MiB; TODO confirm against get_device_memory_capacity.
        gpu_mem = get_device_memory_capacity(self.device)

        # Set mem fraction static, which depends on the tensor * pipeline
        # parallelism size and on the device memory capacity.
        if self.mem_fraction_static is None:
            parallel_size = self.tp_size * self.pp_size
            if gpu_mem is not None and gpu_mem <= 81920:
                if parallel_size >= 16:
                    self.mem_fraction_static = 0.79
                elif parallel_size >= 8:
                    self.mem_fraction_static = 0.81
                elif parallel_size >= 4:
                    self.mem_fraction_static = 0.85
                elif parallel_size >= 2:
                    self.mem_fraction_static = 0.87
                else:
                    self.mem_fraction_static = 0.88
            else:
                self.mem_fraction_static = 0.88
            if gpu_mem is not None and gpu_mem > 180 * 1000 and is_cuda():
                self.mem_fraction_static = 0.79
            elif gpu_mem is not None and gpu_mem > 96 * 1024:
                mem_fraction = self.mem_fraction_static
                # 15 GB + additional 3GB for cuda graph
                reserve_mem = 1024 * 18
                # need reserve more memory for spec cuda graph
                if self.speculative_algorithm is not None:
                    reserve_mem = 1024 * 20
                self.mem_fraction_static = min(
                    mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem,
                    (gpu_mem - reserve_mem) / gpu_mem,
                )
            else:
                if self.speculative_algorithm is not None:
                    self.mem_fraction_static *= 0.95

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None and gpu_mem > 180_000:
                self.chunked_prefill_size = 16384
            elif gpu_mem is not None and gpu_mem < 25_000:
                self.chunked_prefill_size = 2048
            elif self.disaggregation_mode != "null":
                self.chunked_prefill_size = 16384
            else:
                self.chunked_prefill_size = 8192
        # NOTE(review): this check uses page_size as passed in. The
        # flashmla / cutlass_mla overrides below change page_size afterwards,
        # so the divisibility is not re-validated against 64/128. Confirm
        # whether that is intended.
        assert self.chunked_prefill_size % self.page_size == 0

        assert self.moe_dense_tp_size in {
            1,
            None,
        }, "moe_dense_tp_size only support 1 and None currently"

        if self.attention_backend == "flashmla":
            logger.warning(
                "FlashMLA only supports a page_size of 64, change page_size to 64."
            )
            self.page_size = 64

        if self.attention_backend == "cutlass_mla":
            logger.warning(
                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on
            # lower-end GPUs with HBM<25G, you can either disable cuda graph or
            # set `cuda_graph_max_bs` to a very small value to reduce the memory
            # overhead of creating cuda graphs, with almost no impact on
            # performance. However, when serving models with TP4 or TP8, we
            # need to enable cuda graph to maintain high performance. In this
            # case, we can set `cuda_graph_max_bs` to 80 (half of the default
            # value 160) to reduce the memory overhead of creating cuda graphs.
            # Looking at the logs from TP4 serving of qwen2-72b, a value of 80
            # is sufficient and can reduce the memory overhead of creating cuda
            # graphs on lower-end GPUs compared to the original 160, avoiding
            # OOM issues.
            if gpu_mem is not None and gpu_mem < 25_000:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80

        # Set kernel backends for hpu device
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        # Set kernel backends
        if self.device == "cpu":
            if self.attention_backend is None:
                self.attention_backend = "intel_amx"
            self.sampling_backend = "pytorch"

        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True

        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
            )

        if self.enable_dp_lm_head:
            # The message must name the flag that triggered the check.
            assert (
                self.enable_dp_attention
            ), "Please enable dp attention when setting enable_dp_lm_head. "

        # DeepEP MoE
        self.enable_sp_layernorm = False
        if self.enable_deepep_moe:
            if self.deepep_mode == "auto":
                assert (
                    not self.enable_dp_attention
                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
            if self.deepep_mode == "normal":
                logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                self.disable_cuda_graph = True
            self.ep_size = self.tp_size
            self.enable_sp_layernorm = (
                self.dp_size < self.tp_size if self.enable_dp_attention else True
            )
            logger.warning(
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        if self.pp_size > 1:
            self.disable_overlap_schedule = True
            logger.warning(
                "Pipeline parallelism is incompatible with overlap schedule."
            )

        if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
            self.expert_distribution_recorder_mode = "stat"
            logger.info(
                "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
            )

        # NOTE(review): init_expert_location defaults to "trivial" (a str), so
        # `is not None` is true unless a caller passes None explicitly. In
        # practice this defaults ep_dispatch_algorithm to "static" whenever it
        # is unset. Confirm that is intended.
        if (self.enable_eplb or (self.init_expert_location is not None)) and (
            self.ep_dispatch_algorithm is None
        ):
            self.ep_dispatch_algorithm = "static"
            logger.info(
                "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
            )

        if self.enable_expert_distribution_metrics and (
            self.expert_distribution_recorder_mode is None
        ):
            self.expert_distribution_recorder_mode = "stat"

        if self.expert_distribution_recorder_buffer_size is None:
            # NOTE(review): eplb_rebalance_num_iterations defaults to 1000, so
            # the first branch is normally taken even when EPLB is disabled.
            if (x := self.eplb_rebalance_num_iterations) is not None:
                self.expert_distribution_recorder_buffer_size = x
            elif self.expert_distribution_recorder_mode is not None:
                self.expert_distribution_recorder_buffer_size = 1000

        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
            if self.max_running_requests is None:
                self.max_running_requests = 48
            self.disable_overlap_schedule = True
            logger.warning(
                "Overlap scheduler is disabled because of using "
                "eagle speculative decoding."
            )

            model_arch = get_model_arch(self)

            # Auto set draft_model_path DeepSeek-V3/R1
            if model_arch == "DeepseekV3ForCausalLM":
                if self.speculative_draft_model_path is None:
                    self.speculative_draft_model_path = self.model_path
                else:
                    logger.warning(
                        "DeepSeek MTP does not require setting speculative_draft_model_path."
                    )

            # Auto choose parameters
            if self.speculative_num_steps is None:
                assert (
                    self.speculative_eagle_topk is None
                    and self.speculative_num_draft_tokens is None
                )
                (
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
                ) = auto_choose_speculative_params(model_arch)

            # NOTE(review): if speculative_num_steps was given but
            # speculative_eagle_topk was left as None, the `> 1` comparison
            # below raises TypeError when page_size > 1.
            if self.page_size > 1 and self.speculative_eagle_topk > 1:
                self.speculative_eagle_topk = 1
                logger.warning(
                    "speculative_eagle_topk is adjusted to 1 when page_size > 1"
                )

            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
            ):
                logger.warning(
                    "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
                )
                self.speculative_num_draft_tokens = self.speculative_num_steps + 1

            # The token generated from the verify step is counted.
            # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

        if is_remote_url(self.model_path):
            self.load_format = "remote"

        # AMD-specific Triton attention KV splits default number
        if is_hip():
            self.triton_attention_num_kv_splits = 16

        # PD disaggregation
        if self.disaggregation_mode == "prefill":
            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for prefill server")
        elif self.disaggregation_mode == "decode":
            self.disable_radix_cache = True
            logger.warning("KV cache is forced as chunk cache for decode server")

        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
500
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
501
502
503
504
505
506
507
508
509
510
511
512
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
513
514
515
516
517
518
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
519
520
521
522
523
524
525
526
527
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
528
529
530
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
531
            help="If set, skip init tokenizer and pass input_ids in generate request.",
532
        )
533
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
534
535
536
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
537
538
539
540
541
542
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
543
                "sharded_state",
544
545
                "gguf",
                "bitsandbytes",
546
                "layered",
547
                "remote",
548
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
549
550
551
552
553
554
555
556
557
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
558
            "which is mainly for profiling."
559
560
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
561
562
563
564
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
565
        )
566
567
568
569
570
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
571
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
572
            "--dtype",
Cody Yu's avatar
Cody Yu committed
573
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
574
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
575
576
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
577
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
578
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
579
580
581
582
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
583
584
            '* "float32" for FP32 precision.',
        )
585
586
587
588
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
bjmsong's avatar
bjmsong committed
589
590
591
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
592
593
594
595
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
596
597
598
599
600
601
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
602
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
603
                "bitsandbytes",
604
                "gguf",
605
                "modelopt",
606
                "modelopt_fp4",
607
                "w8a8_int8",
HandH1998's avatar
HandH1998 committed
608
                "w8a8_fp8",
AniZpZ's avatar
AniZpZ committed
609
                "moe_wna16",
HandH1998's avatar
HandH1998 committed
610
                "qoq",
Ying Sheng's avatar
Ying Sheng committed
611
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
612
613
            help="The quantization method.",
        )
614
615
616
617
618
619
620
621
622
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
623
624
625
626
627
628
629
630
631
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
632
            default=ServerArgs.device,
633
            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
634
        )
635
636
637
638
639
640
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
641
642
643
644
645
646
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
647
648
649
650
651
652
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
        )
653
654
655
656
657
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
658
659
660
661
662
663
        parser.add_argument(
            "--enable-multimodal",
            default=ServerArgs.enable_multimodal,
            action="store_true",
            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
        )
664
665
666
667
668
669
670
671
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
672

673
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
674
675
676
677
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
678
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
679
        )
680
681
682
683
684
685
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
686
687
688
689
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
690
691
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
692
        )
693
694
695
696
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
697
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
698
699
700
701
702
703
704
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
705
        parser.add_argument(
706
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
707
            type=str,
708
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
709
            choices=["lpm", "random", "fcfs", "dfs-weight"],
710
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
711
        )
712
713
714
715
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
716
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
717
        )
718
719
720
721
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
722
            help="How many GBs of RAM to reserve for CPU offloading.",
723
        )
724
725
726
727
728
729
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
730
731
732
733
734
735
736
737
738
739
740
741
        parser.add_argument(
            "--impl",
            type=str,
            default=ServerArgs.impl,
            help="Which implementation of the model to use.\n\n"
            '* "auto" will try to use the SGLang implementation if it exists '
            "and fall back to the Transformers implementation if no SGLang "
            "implementation is available.\n"
            '* "sglang" will use the SGLang model implementation.\n'
            '* "transformers" will use the Transformers model '
            "implementation.\n",
        )
742

743
        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
744
        parser.add_argument(
745
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
746
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
747
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
748
            default=ServerArgs.tp_size,
749
            help="The tensor parallelism size.",
750
        )
751
752
753
754
755
756
757
758
759
760
761
762
763
        parser.add_argument(
            "--pipeline-parallel-size",
            "--pp-size",
            type=int,
            default=ServerArgs.pp_size,
            help="The pipeline parallelism size.",
        )
        parser.add_argument(
            "--max-micro-batch-size",
            type=int,
            default=ServerArgs.max_micro_batch_size,
            help="The maximum micro batch size in pipeline parallelism.",
        )
764
765
766
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
767
            default=ServerArgs.stream_interval,
768
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
769
        )
770
771
772
773
774
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
775
776
777
778
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
779
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
780
        )
781
782
783
784
785
786
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
787
788
789
790
791
792
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
793
794
795
796
797
798
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
799
800
801
802
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
803
            help="Model download directory for huggingface.",
804
        )
805
806
807
808
809
810
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
811
812
813
814
815
816
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )
817
818

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
819
820
821
822
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
823
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
824
        )
825
        parser.add_argument(
826
827
828
829
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
830
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
831
        parser.add_argument(
832
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
833
            action="store_true",
834
835
836
837
838
839
840
841
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
            choices=[0, 1, 2],
Lianmin Zheng's avatar
Lianmin Zheng committed
842
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
843
844
845
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
846
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
847
        )
848
849
850
851
852
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
        parser.add_argument(
            "--bucket-time-to-first-token",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_time_to_first_token,
            help="The buckets of time to first token, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-inter-token-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_inter_token_latency,
            help="The buckets of inter-token latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--bucket-e2e-request-latency",
            type=float,
            nargs="+",
            default=ServerArgs.bucket_e2e_request_latency,
            help="The buckets of end-to-end request latency, specified as a list of floats.",
        )
        parser.add_argument(
            "--collect-tokens-histogram",
            action="store_true",
            default=ServerArgs.collect_tokens_histogram,
            help="Collect prompt/generation tokens histogram.",
        )
880
881
882
883
884
885
        parser.add_argument(
            "--kv-events-config",
            type=str,
            default=None,
            help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
        )
886
887
888
889
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
890
            help="The log interval of decode batch.",
891
        )
892
893
894
895
896
897
        parser.add_argument(
            "--enable-request-time-stats-logging",
            action="store_true",
            default=ServerArgs.enable_request_time_stats_logging,
            help="Enable per request time stats logging",
        )
898

899
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
900
901
902
903
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
904
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
905
        )
906
        parser.add_argument(
907
            "--file-storage-path",
908
            type=str,
909
            default=ServerArgs.file_storage_path,
910
911
            help="The path of the file storage in backend.",
        )
912
913
914
915
916
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Xihuai Wang's avatar
Xihuai Wang committed
917
918
919
920
921
922
923
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
924

925
926
        # Data parallelism
        parser.add_argument(
927
            "--data-parallel-size",
928
929
930
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
931
            help="The data parallelism size.",
932
933
934
935
936
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
937
            help="The load balancing strategy for data parallelism.",
938
939
940
941
942
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
943

xiaobochen's avatar
xiaobochen committed
944
945
946
947
948
949
950
951
        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
952

953
        # Multi-node distributed serving
954
        parser.add_argument(
955
            "--dist-init-addr",
956
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
957
            type=str,
958
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
959
960
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
961
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
962
        )
963
964
965
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
966

Lianmin Zheng's avatar
Lianmin Zheng committed
967
968
969
970
971
972
973
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
974
975
976
977
978
        parser.add_argument(
            "--preferred-sampling-params",
            type=str,
            help="json-formatted sampling settings that will be returned in /get_model_info",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
979

980
981
982
983
984
985
986
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
987
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
988
989
990
991
992
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
993
994
995
996
997
998
999
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
1000
1001
1002
        )

        # Kernel backend
1003
1004
1005
        parser.add_argument(
            "--attention-backend",
            type=str,
1006
            choices=[
1007
                "aiter",
1008
1009
1010
1011
1012
1013
                "flashinfer",
                "triton",
                "torch_native",
                "fa3",
                "flashmla",
                "cutlass_mla",
1014
                "intel_amx",
1015
            ],
1016
1017
1018
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
1019
1020
1021
1022
1023
1024
1025
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
1026
1027
1028
        parser.add_argument(
            "--grammar-backend",
            type=str,
1029
            choices=["xgrammar", "outlines", "llguidance", "none"],
1030
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
1031
            help="Choose the backend for grammar-guided decoding.",
1032
        )
1033
1034
        parser.add_argument(
            "--enable-flashinfer-mla",
1035
1036
            action=DeprecatedAction,
            help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
1037
        )
lukec's avatar
lukec committed
1038
1039
        parser.add_argument(
            "--enable-flashmla",
1040
1041
            action=DeprecatedAction,
            help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
lukec's avatar
lukec committed
1042
        )
1043
1044
1045
1046
1047
        parser.add_argument(
            "--flashinfer-mla-disable-ragged",
            action="store_true",
            help="Not using ragged prefill wrapper when running flashinfer mla",
        )
1048

1049
1050
1051
1052
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
James Liu's avatar
James Liu committed
1053
            choices=["EAGLE", "EAGLE3", "NEXTN"],
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
1070
            help="The number of tokens sampled from the draft model in eagle2 each step.",
1071
1072
            default=ServerArgs.speculative_eagle_topk,
        )
1073
1074
1075
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
1076
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
1077
1078
            default=ServerArgs.speculative_num_draft_tokens,
        )
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
1091
1092
1093
1094
1095
1096
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

1135
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
1136
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
1137
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
1138
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
1139
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
1140
        )
1141
1142
1143
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
1144
            help="Disable cuda graph.",
1145
        )
1146
        parser.add_argument(
1147
1148
            "--disable-cuda-graph-padding",
            action="store_true",
1149
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
1150
        )
1151
1152
1153
1154
1155
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
1156
1157
1158
1159
1160
        parser.add_argument(
            "--enable-tokenizer-batch-encode",
            action="store_true",
            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
        )
1161
        parser.add_argument(
1162
            "--disable-outlines-disk-cache",
1163
            action="store_true",
1164
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
1165
        )
1166
1167
1168
1169
1170
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1171
        parser.add_argument(
1172
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
1173
            action="store_true",
1174
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1175
        )
1176
1177
1178
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
1179
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
1180
        )
Ke Bao's avatar
Ke Bao committed
1181
1182
1183
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
1184
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
Ke Bao's avatar
Ke Bao committed
1185
        )
1186
1187
1188
1189
1190
        parser.add_argument(
            "--enable-dp-lm-head",
            action="store_true",
            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
        )
xiaobochen's avatar
xiaobochen committed
1191
1192
1193
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
1194
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
xiaobochen's avatar
xiaobochen committed
1195
        )
1196
1197
1198
1199
1200
        parser.add_argument(
            "--enable-two-batch-overlap",
            action="store_true",
            help="Enabling two micro batches to overlap.",
        )
1201
1202
1203
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
1204
1205
            help="Optimize the model with torch.compile. Experimental feature.",
        )
1206
        parser.add_argument(
1207
            "--torch-compile-max-bs",
1208
            type=int,
1209
            default=ServerArgs.torch_compile_max_bs,
1210
1211
            help="Set the maximum batch size when using torch compile.",
        )
1212
        parser.add_argument(
1213
            "--cuda-graph-max-bs",
1214
            type=int,
1215
            default=ServerArgs.cuda_graph_max_bs,
1216
            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
1217
        )
1218
1219
1220
1221
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
1222
            help="Set the list of batch sizes for cuda graph.",
1223
        )
1224
1225
1226
1227
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
1228
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
1229
        )
1230
1231
1232
1233
1234
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1235
        parser.add_argument(
1236
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
1237
            action="store_true",
1238
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
1239
        )
1240
        parser.add_argument(
1241
            "--triton-attention-reduce-in-fp32",
1242
            action="store_true",
1243
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
1244
            "This only affects Triton attention kernels.",
1245
        )
1246
1247
1248
1249
1250
1251
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
1252
1253
1254
1255
1256
1257
1258
1259
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
1260
1261
1262
1263
1264
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
1265
1266
1267
1268
1269
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
1270
1271
1272
1273
1274
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
1275
1276
1277
1278
1279
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
YAMY's avatar
YAMY committed
1280
1281
1282
        parser.add_argument(
            "--tool-call-parser",
            type=str,
1283
            choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
YAMY's avatar
YAMY committed
1284
            default=ServerArgs.tool_call_parser,
1285
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
YAMY's avatar
YAMY committed
1286
        )
1287
1288
1289
1290
1291
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
1292
1293
1294
1295
1296
1297
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
Zhiqiang Xie's avatar
Zhiqiang Xie committed
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
        parser.add_argument(
            "--hicache-size",
            type=int,
            default=ServerArgs.hicache_size,
            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
        )
        parser.add_argument(
            "--hicache-write-policy",
            type=str,
            choices=["write_back", "write_through", "write_through_selective"],
            default=ServerArgs.hicache_write_policy,
            help="The write policy of hierarchical cache.",
        )
1311
1312
1313
1314
1315
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="Enabling DeepEP MoE implementation for EP MoE.",
        )
1316
1317
1318
1319
1320
1321
        parser.add_argument(
            "--moe-dense-tp-size",
            type=int,
            default=ServerArgs.moe_dense_tp_size,
            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
        )
1322
1323
1324
1325
        parser.add_argument(
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
1326
            default="auto",
1327
1328
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
1329
1330
1331
1332
1333
1334
        parser.add_argument(
            "--ep-num-redundant-experts",
            type=int,
            default=ServerArgs.ep_num_redundant_experts,
            help="Allocate this number of redundant experts in expert parallel.",
        )
1335
1336
1337
1338
1339
1340
        parser.add_argument(
            "--ep-dispatch-algorithm",
            type=str,
            default=ServerArgs.ep_dispatch_algorithm,
            help="The algorithm to choose ranks for redundant experts in expert parallel.",
        )
1341
1342
1343
1344
1345
1346
        parser.add_argument(
            "--init-expert-location",
            type=str,
            default=ServerArgs.init_expert_location,
            help="Initial location of EP experts.",
        )
1347
1348
1349
1350
1351
        parser.add_argument(
            "--enable-eplb",
            action="store_true",
            help="Enable EPLB algorithm",
        )
1352
1353
1354
1355
1356
1357
        parser.add_argument(
            "--eplb-algorithm",
            type=str,
            default=ServerArgs.eplb_algorithm,
            help="Chosen EPLB algorithm",
        )
1358
1359
1360
1361
1362
1363
        parser.add_argument(
            "--eplb-rebalance-num-iterations",
            type=int,
            default=ServerArgs.eplb_rebalance_num_iterations,
            help="Number of iterations to automatically trigger a EPLB re-balance.",
        )
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
        parser.add_argument(
            "--expert-distribution-recorder-mode",
            type=str,
            default=ServerArgs.expert_distribution_recorder_mode,
            help="Mode of expert distribution recorder.",
        )
        parser.add_argument(
            "--expert-distribution-recorder-buffer-size",
            type=int,
            default=ServerArgs.expert_distribution_recorder_buffer_size,
            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
        )
1376
1377
1378
1379
1380
        parser.add_argument(
            "--enable-expert-distribution-metrics",
            action="store_true",
            help="Enable logging metrics for expert balancedness",
        )
1381
1382
1383
1384
        parser.add_argument(
            "--deepep-config",
            type=str,
            default=ServerArgs.deepep_config,
1385
            help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
1386
        )
1387
        parser.add_argument(
1388
1389
1390
            "--disable-shared-experts-fusion",
            action="store_true",
            help="Disable shared experts fusion optimization for deepseek v3/r1.",
1391
        )
1392
1393
1394
1395
1396
        parser.add_argument(
            "--disable-chunked-prefix-cache",
            action="store_true",
            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1397
1398
1399
1400
1401
        parser.add_argument(
            "--disable-fast-image-processor",
            action="store_true",
            help="Adopt base image processor instead of fast image processor.",
        )
1402

1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
        # Server warmups
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )

Byron Hsu's avatar
Byron Hsu committed
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
        # Disaggregation
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )
1446
1447
1448
1449
        parser.add_argument(
            "--disaggregation-transfer-backend",
            type=str,
            default=ServerArgs.disaggregation_transfer_backend,
1450
            choices=["mooncake", "nixl"],
1451
1452
            help="The backend for disaggregation transfer. Default is mooncake.",
        )
1453
1454
1455
1456
        parser.add_argument(
            "--disaggregation-ib-device",
            type=str,
            default=ServerArgs.disaggregation_ib_device,
1457
1458
1459
            help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
            "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
            "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
1460
        )
1461
1462
1463
1464
1465
1466
        parser.add_argument(
            "--pdlb-url",
            type=str,
            default=None,
            help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
        )
Byron Hsu's avatar
Byron Hsu committed
1467

1468
1469
1470
1471
1472
1473
1474
1475
        parser.add_argument(
            "--mm-attention-backend",
            type=str,
            choices=["sdpa", "fa3", "triton_attn"],
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
1476
1477
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a parsed ``argparse.Namespace``.

        The long-form parallelism options (``tensor_parallel_size`` etc.) are
        first copied onto ``args`` under their short dataclass field names
        (``tp_size`` etc.). Note that this mutates ``args``. Every dataclass
        field is then read from ``args`` by name and passed to the constructor.
        """
        # Alias long CLI option names onto the short dataclass field names.
        for short_name, long_name in (
            ("tp_size", "tensor_parallel_size"),
            ("pp_size", "pipeline_parallel_size"),
            ("dp_size", "data_parallel_size"),
            ("ep_size", "expert_parallel_size"),
        ):
            setattr(args, short_name, getattr(args, long_name))

        field_values = {
            field.name: getattr(args, field.name) for field in dataclasses.fields(cls)
        }
        return cls(**field_values)

    def url(self):
        """Return the HTTP base URL (``http://host:port``) of this server.

        When ``self.host`` is an IPv6 literal it is wrapped in square
        brackets, as URLs require.
        """
        host = f"[{self.host}]" if is_valid_ipv6_address(self.host) else self.host
        return f"http://{host}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
1490

1491
1492
    def check_server_args(self):
        """Validate cross-field constraints and normalize ``lora_paths``.

        Raises:
            AssertionError: if any constraint below is violated. (Kept as
                ``assert`` for backward compatibility; note these checks are
                skipped when Python runs with ``-O``.)

        Side effect: when ``lora_paths`` is a list, it is replaced in place by
        a dict. A ``name=path`` entry maps ``name`` to ``path`` (split on the
        first ``=``), and a bare entry maps to itself.
        """
        assert (self.tp_size * self.pp_size) % self.nnodes == 0, (
            "tp_size * pp_size must be divisible by number of nodes "
            f"(tp_size={self.tp_size}, pp_size={self.pp_size}, nnodes={self.nnodes})"
        )

        # FIXME pp constraints
        if self.pp_size > 1:
            assert (
                self.disable_overlap_schedule
                and self.speculative_algorithm is None
                and not self.enable_mixed_chunk
            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"

        # Checked separately from the LoRA/radix constraint so a failure
        # reports the condition that actually failed.
        assert (
            self.max_loras_per_batch > 0
        ), f"max_loras_per_batch must be positive, got {self.max_loras_per_batch}"
        # FIXME: LoRA currently requires the radix cache to be disabled.
        assert (
            self.lora_paths is None or self.disable_radix_cache
        ), "compatibility of lora and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

        if isinstance(self.lora_paths, list):
            entries = self.lora_paths
            self.lora_paths = {}
            for entry in entries:
                if "=" in entry:
                    name, path = entry.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[entry] = entry

Lianmin Zheng's avatar
Lianmin Zheng committed
1525

Lianmin Zheng's avatar
Lianmin Zheng committed
1526
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """Parse command-line arguments into a ``ServerArgs``.

    Args:
        argv: The command line arguments. Typically, it should be
            ``sys.argv[1:]`` to ensure compatibility with ``parse_args`` when
            no arguments are passed.

    Returns:
        The server arguments.
    """
    cli_parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli_parser)
    parsed = cli_parser.parse_args(argv)
    return ServerArgs.from_cli_args(parsed)


1544
1545
1546
# Offset added to ``server_args.port`` to form the head-node TCP port used as
# the base for ZMQ endpoints when DP attention runs on a single node without
# an explicit --dist-init-addr (see PortArgs.init_new).
ZMQ_TCP_PORT_DELTA = 233


Lianmin Zheng's avatar
Lianmin Zheng committed
1547
1548
@dataclasses.dataclass
class PortArgs:
    """Inter-process endpoints used by the server's components.

    The ``*_ipc_name`` fields are ZMQ endpoint strings (``ipc://<path>`` in the
    single-node case, ``tcp://host:port`` under DP attention). ``nccl_port``
    is the port used for NCCL initialization.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    @staticmethod
    def _new_ipc_name() -> str:
        """Return an ``ipc://`` endpoint backed by a fresh temporary file path.

        The temp file is created (and left on disk, ``delete=False``) only to
        obtain a unique path. Its handle is closed immediately rather than
        being left open for the garbage collector.
        """
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            return f"ipc://{tmp.name}"

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate a new set of endpoints for ``server_args``.

        Args:
            server_args: server configuration. Reads ``port``,
                ``enable_dp_attention``, ``nnodes`` and ``dist_init_addr``.
            dp_rank: only used with DP attention. ``None`` selects the
                TokenizerManager -> DataParallelController port; an integer
                selects a per-rank scheduler input port.

        Returns:
            A populated ``PortArgs``.

        Raises:
            AssertionError: with DP attention, if ``dist_init_addr`` is missing
                on a multi-node setup or does not parse as ``host:port``.
        """
        # Probe for a free NCCL port starting at a random offset above the
        # server port: step up by 42 below 60000, down by 43 at/above it.
        port = server_args.port + random.randint(100, 1000)
        while not is_port_available(port):
            if port < 60000:
                port += 42
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=PortArgs._new_ipc_name(),
                scheduler_input_ipc_name=PortArgs._new_ipc_name(),
                detokenizer_ipc_name=PortArgs._new_ipc_name(),
                nccl_port=port,
                rpc_ipc_name=PortArgs._new_ipc_name(),
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
        else:
            # Multi-node needs an explicit head address; fail with a clear
            # message instead of an AttributeError on None.
            assert (
                server_args.dist_init_addr is not None
            ), "please provide --dist-init-addr as host:port of head node"
            if server_args.dist_init_addr.startswith("["):  # ipv6 address
                port_num, host = configure_ipv6(server_args.dist_init_addr)
                dist_init_addr = (host, str(port_num))
            else:
                dist_init_addr = server_args.dist_init_addr.split(":")

        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"

        dist_init_host, dist_init_port = dist_init_addr
        port_base = int(dist_init_port) + 1
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 3
        else:
            scheduler_input_port = port_base + 3 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            nccl_port=port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
        )
1612

1613
1614
1615
1616
1617
1618
1619
1620
1621
1622

class LoRAPathAction(argparse.Action):
    """argparse action that collects LoRA adapters into a ``{name: path}`` dict.

    Each value is either ``name=path`` (split on the first ``=``) or a bare
    ``path``, which is used as both the name and the path.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        setattr(namespace, self.dest, adapters)
        for item in values:
            name, sep, path = item.partition("=")
            if sep:
                adapters[name] = path
            else:
                adapters[item] = item
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632


class DeprecatedAction(argparse.Action):
    """argparse action that rejects a removed/deprecated option.

    The option consumes no values by default (``nargs=0``). Using it raises
    ``ValueError`` with the option's ``help`` text as the message.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # The help string doubles as the user-facing deprecation message.
        raise ValueError(self.help)
1633
1634


1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
def get_model_arch(args: ServerArgs):
    """Return the first entry of ``architectures`` in the model's HF config.

    The config is loaded via ``get_config`` using the model path, the
    ``trust_remote_code`` and ``revision`` settings, and the JSON overrides
    in ``args.json_model_override_args``.
    """
    override_args = json.loads(args.json_model_override_args)
    config = get_config(
        args.model_path,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
        model_override_args=override_args,
    )
    return config.architectures[0]


def auto_choose_speculative_params(arch: str):
    """
    Automatically choose the parameters for speculative decoding.

    Returns a 3-tuple of defaults for the given model architecture name.
    NOTE(review): presumably (num_steps, eagle_topk, num_draft_tokens); confirm
    against the caller.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
    """
    # DeepSeek V2/V3 use their own tuned defaults.
    if arch in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
        return (3, 1, 4)
    # Llama, Grok-1 (incl. Grok1V) and every other architecture share one default.
    return (5, 4, 8)