"vscode:/vscode.git/clone" did not exist on "1605ae121e6c792e4f38813814b287b3c8669eb5"
server_args.py 47.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The arguments of the server."""

import argparse
import dataclasses
import logging
import os
import random
import tempfile
from typing import List, Optional
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
    get_amdgpu_memory_capacity,
    get_device,
    get_hpu_memory_capacity,
    get_nvgpu_memory_capacity,
    is_cuda,
    is_flashinfer_available,
    is_hip,
    is_port_available,
    is_remote_url,
    is_valid_ipv6_address,
    nullable_str,
)
logger = logging.getLogger(__name__)

@dataclasses.dataclass
class ServerArgs:
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    trust_remote_code: bool = False
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    quantization_param_path: nullable_str = None
    context_length: Optional[int] = None
    device: Optional[str] = None
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    completion_template: Optional[str] = None
    is_embedding: bool = False
    revision: Optional[str] = None

    # Port for the HTTP server
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "fcfs"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    page_size: int = 1

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    stream_output: bool = False
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
    download_dir: Optional[str] = None
    base_gpu_id: int = 0
    gpu_id_step: int = 1

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    log_requests_level: int = 0
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_path: str = "sglang_storage"
    enable_cache_report: bool = False
    reasoning_parser: Optional[str] = None

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Expert parallelism
    ep_size: int = 1

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "xgrammar"

    # Speculative decoding
    speculative_algorithm: Optional[str] = None
    speculative_draft_model_path: Optional[str] = None
    speculative_num_steps: int = 5
    speculative_eagle_topk: int = 4
    speculative_num_draft_tokens: int = 8
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
    speculative_token_map: Optional[str] = None

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    enable_custom_logit_processor: bool = False
    tool_call_parser: Optional[str] = None
    enable_hierarchical_cache: bool = False
    hicache_ratio: float = 2.0
    enable_flashinfer_mla: bool = False
    enable_flashmla: bool = False
    flashinfer_mla_disable_ragged: bool = False
    warmups: Optional[str] = None

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False
    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
    disaggregation_mode: str = "null"
    disaggregation_bootstrap_port: int = 8998

    def __post_init__(self):
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
        if self.device is None:
            self.device = get_device()

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        if is_cuda():
            gpu_mem = get_nvgpu_memory_capacity()
        elif is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        elif self.device == "hpu":
            gpu_mem = get_hpu_memory_capacity()
        else:
            # GPU memory is not known yet or no GPU is available.
            gpu_mem = None

        # Set mem fraction static, which depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.81
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None and gpu_mem < 25_000:
                self.chunked_prefill_size = 2048
            else:
                self.chunked_prefill_size = 8192

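        # Chunked prefill chunks must align to page boundaries of the paged KV
        # cache, hence the divisibility check below.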
        assert self.chunked_prefill_size % self.page_size == 0

        if self.enable_flashmla:
            logger.warning(
                "FlashMLA only supports a page_size of 64; changing page_size to 64."
            )
            self.page_size = 64

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics: when serving TP1/TP2 models on lower-end
            # GPUs with HBM < 25 GB, you can either disable cuda graph or set
            # `cuda_graph_max_bs` to a very small value to reduce the memory overhead
            # of creating cuda graphs, with almost no impact on performance. However,
            # when serving models with TP4 or TP8, cuda graph must stay enabled to
            # maintain high performance. In that case, `cuda_graph_max_bs` can be set
            # to 80 (half of the default 160) to reduce the memory overhead of
            # creating cuda graphs: logs from TP4 serving of qwen2-72b show that 80
            # is sufficient and avoids the OOM issues seen with 160 on lower-end GPUs.
            if gpu_mem is not None and gpu_mem < 25_000:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80
            else:
                self.cuda_graph_max_bs = 160

        # Choose kernel backends
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        if self.attention_backend is None:
            self.attention_backend = (
                "flashinfer" if is_flashinfer_available() else "triton"
            )
        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True
        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.info(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
            )

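        # Note: with DeepEP MoE, sequence-parallel layernorm is enabled unless
        # DP attention already spans all TP ranks (dp_size == tp_size), per the
        # condition below.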
        self.enable_sp_layernorm = False
        # DeepEP MoE
        if self.enable_deepep_moe:
            self.ep_size = self.tp_size
            self.enable_sp_layernorm = (
                self.dp_size < self.tp_size if self.enable_dp_attention else True
            )
            logger.info(
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

        if (
            self.speculative_algorithm == "EAGLE"
            or self.speculative_algorithm == "EAGLE3"
        ):
            if self.max_running_requests is None:
                self.max_running_requests = 32
            self.disable_overlap_schedule = True
            logger.info(
                "Overlap scheduler is disabled because of using "
                "eagle speculative decoding."
            )
            # The token generated from the verify step is counted.
            # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

        if is_remote_url(self.model_path):
            self.load_format = "remote"

        # AMD-specific Triton attention KV splits default number
        if is_hip():
            self.triton_attention_num_kv_splits = 16

        # PD disaggregation
        if self.disaggregation_mode == "prefill":
            self.disable_cuda_graph = True
            logger.warning("Cuda graph is disabled for prefill server")
            self.disable_overlap_schedule = True
            logger.warning("Overlap scheduler is disabled for prefill server")
        elif self.disaggregation_mode == "decode":
            self.disable_radix_cache = True
            logger.warning("KV cache is forced as chunk cache for decode server")
            self.disable_overlap_schedule = True
            logger.warning("Overlap scheduler is disabled for decode server")

        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Model and port args
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
                "sharded_state",
                "gguf",
                "bitsandbytes",
                "layered",
                "remote",
            ],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling. "
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
            "quantization. "
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
                "gguf",
                "modelopt",
                "w8a8_int8",
                "w8a8_fp8",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--completion-template",
            type=str,
            default=ServerArgs.completion_template,
            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        # Memory and scheduling
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading.",
        )
        parser.add_argument(
            "--page-size",
            type=int,
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
        # Other runtime options
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher.",
        )
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory for huggingface.",
        )
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )

        # Logging
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
            choices=[0, 1, 2],
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch.",
        )

        # API related
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-path",
            type=str,
            default=ServerArgs.file_storage_path,
            help="The path of the file storage in backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=list(ReasoningParser.DetectorMap.keys()),
            default=ServerArgs.reasoning_parser,
            help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
        )
        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )

        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
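        # Example (hypothetical values):
        #   --json-model-override-args '{"rope_scaling": {"factor": 2.0, "rope_type": "dynamic"}}'
        # overrides the matching fields of the loaded model config.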

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either paths or renamed paths in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, including base-only requests.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton", "torch_native", "fa3"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines", "llguidance"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for grammar-guided decoding.",
        )
        parser.add_argument(
            "--enable-flashinfer-mla",
            action="store_true",
            help="Enable FlashInfer MLA optimization",
        )
        parser.add_argument(
            "--enable-flashmla",
            action="store_true",
            help="Enable FlashMLA decode optimization",
        )
        parser.add_argument(
            "--flashinfer-mla-disable-ragged",
            action="store_true",
            help="Not using ragged prefill wrapper when running flashinfer mla",
        )
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
            choices=["EAGLE", "EAGLE3", "NEXTN"],
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
            help="The number of tokens sampled from the draft model in eagle2 each step.",
            default=ServerArgs.speculative_eagle_topk,
        )
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_draft_tokens,
        )
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
        parser.add_argument(
            "--disable-outlines-disk-cache",
            action="store_true",
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enable data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enable expert parallelism for MoE. The ep size is equal to the tp size.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
        )
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
        parser.add_argument(
            "--tool-call-parser",
            type=str,
            choices=["qwen25", "mistral", "llama3"],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
        )
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
        parser.add_argument(
            "--hicache-ratio",
            type=float,
            required=False,
            default=ServerArgs.hicache_ratio,
            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
        )
        parser.add_argument(
            "--enable-deepep-moe",
            action="store_true",
            help="Enabling DeepEP MoE implementation for EP MoE.",
        )
        # Server warmups
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )

        # Disaggregation
        parser.add_argument(
            "--disaggregation-mode",
            type=str,
            default="null",
            choices=["null", "prefill", "decode"],
            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
        )
        parser.add_argument(
            "--disaggregation-bootstrap-port",
            type=int,
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        args.ep_size = args.expert_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        if is_valid_ipv6_address(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention is enabled!"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

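        # Normalize --lora-paths entries into a dict: "name=path" maps name to
        # path, while a bare path maps to itself.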
        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path

def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


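# Offset added to the server port to derive the ZMQ TCP port when DP attention
# runs on a single node (see PortArgs.init_new below).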
ZMQ_TCP_PORT_DELTA = 233


@dataclasses.dataclass
class PortArgs:
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str
    # The port for nccl initialization (torch.dist)
    nccl_port: int
    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        port = server_args.port + random.randint(100, 1000)
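        # Probe for a free port: step up by 42 while below 60000, then step
        # back down by 43 to stay within the usable range.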
        while True:
            if is_port_available(port):
                break
            if port < 60000:
                port += 42
            else:
                port -= 43
        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                nccl_port=port,
                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
            )
        else:
            # DP attention. Use TCP + port to handle both single-node and multi-node.
            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
            else:
                dist_init_addr = server_args.dist_init_addr.split(":")
            assert (
                len(dist_init_addr) == 2
            ), "please provide --dist-init-addr as host:port of head node"

            dist_init_host, dist_init_port = dist_init_addr
            port_base = int(dist_init_port) + 1
            if dp_rank is None:
                scheduler_input_port = (
                    port_base + 2
                )  # TokenizerManager to DataParallelController
            else:
                scheduler_input_port = port_base + 2 + 1 + dp_rank

            return PortArgs(
                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                nccl_port=port,
                rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
            )


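# Parses each --lora-paths value into a dict entry. Example (hypothetical
# flags): `--lora-paths sql=/adapters/sql /adapters/chat` yields
# {"sql": "/adapters/sql", "/adapters/chat": "/adapters/chat"}.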
class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path


class DeprecatedAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super(DeprecatedAction, self).__init__(
            option_strings, dest, nargs=nargs, **kwargs
        )

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)