"vscode:/vscode.git/clone" did not exist on "3d64fda376d34c573a37572450a59e6cb45f4cd5"
server_args.py 39.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import logging
19
import random
20
21
import tempfile
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
22

23
24
import torch

25
from sglang.srt.hf_transformers_utils import check_gguf_file
26
from sglang.srt.utils import (
HAI's avatar
HAI committed
27
    get_amdgpu_memory_capacity,
28
    get_hpu_memory_capacity,
HAI's avatar
HAI committed
29
    get_nvgpu_memory_capacity,
30
    is_flashinfer_available,
HAI's avatar
HAI committed
31
    is_hip,
32
    is_port_available,
33
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
34
    nullable_str,
35
)
36

37
38
# Module-level logger; ServerArgs.__post_init__ uses it to report settings it
# adjusts automatically (e.g. disabled CUDA graph, resized chunked prefill).
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
39
40
41

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
42
    # NOTE: field order is significant -- the dataclass-generated __init__
    # accepts these fields positionally in exactly this order, so new fields
    # should be appended rather than inserted or reordered.

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None  # defaults to model_path in __post_init__
    tokenizer_mode: str = "auto"
    load_format: str = "auto"
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization_param_path: nullable_str = None
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None  # defaults to model_path in __post_init__
    chat_template: Optional[str] = None
    is_embedding: bool = False
    revision: Optional[str] = None
    skip_tokenizer_init: bool = False

    # HTTP server address
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling (None values are filled in by __post_init__)
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    prefill_only_one_req: bool = False

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    stream_output: bool = False
    random_seed: Optional[int] = None  # randomized in __post_init__ when unset
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    download_dir: Optional[str] = None
    base_gpu_id: int = 0

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_pth: str = "sglang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Expert parallelism
    ep_size: int = 1

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend (None values are filled in by __post_init__)
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Speculative decoding
    speculative_draft_model_path: Optional[str] = None
    speculative_algorithm: Optional[str] = None
    speculative_num_steps: int = 5
    speculative_num_draft_tokens: int = 64
    speculative_eagle_topk: int = 8

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False

    # Custom logit processor and miscellaneous options
    enable_custom_logit_processor: bool = False
    tool_call_parser: Optional[str] = None
    enable_hierarchical_cache: bool = False
167

Lianmin Zheng's avatar
Lianmin Zheng committed
168
    def __post_init__(self):
        """Resolve unset defaults and reconcile interdependent options.

        Called automatically by the dataclass-generated ``__init__``. Fields
        left as ``None`` are derived from other fields or from the detected
        accelerator memory capacity. Several feature flags (torch-native
        attention, EP MoE, DP attention, EAGLE speculative decoding, GGUF
        checkpoints) then override related settings. Adjustments that change
        user-visible behavior are reported through ``logger``.
        """
        # Fields that fall back to another field or to a random value.
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
        if self.served_model_name is None:
            self.served_model_name = self.model_path
        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Accelerator memory capacity; stays None when it is unknown or no
        # supported accelerator is available.
        gpu_mem = None
        if is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        elif torch.cuda.is_available():
            gpu_mem = get_nvgpu_memory_capacity()
        elif self.device == "hpu":
            gpu_mem = get_hpu_memory_capacity()
        # Devices with a known capacity below 25_000 (in the units returned by
        # the get_*_memory_capacity helpers) get lighter defaults below.
        low_mem_gpu = gpu_mem is not None and gpu_mem < 25_000

        # Static memory fraction shrinks as the tensor-parallel size grows.
        if self.mem_fraction_static is None:
            for min_tp_size, fraction in (
                (16, 0.79),
                (8, 0.81),
                (4, 0.85),
                (2, 0.87),
            ):
                if self.tp_size >= min_tp_size:
                    self.mem_fraction_static = fraction
                    break
            else:
                self.mem_fraction_static = 0.88

        # Chunked prefill size depends on the device memory capacity.
        if self.chunked_prefill_size is None:
            self.chunked_prefill_size = 2048 if low_mem_gpu else 8192

        # CUDA graph max batch size. On low-memory devices, TP1/TP2 deployments
        # use a tiny value (8) to limit the memory spent capturing CUDA graphs,
        # with almost no performance impact. TP4/TP8 still need CUDA graphs for
        # performance and use 80 (half of the default 160); TP4 serving logs of
        # qwen2-72b showed 80 is sufficient while avoiding OOM on such GPUs.
        if self.cuda_graph_max_bs is None:
            if not low_mem_gpu:
                self.cuda_graph_max_bs = 160
            elif self.tp_size < 4:
                self.cuda_graph_max_bs = 8
            else:
                self.cuda_graph_max_bs = 80

        # Kernel backends. HPU is pinned to the torch-native/pytorch kernels.
        # Otherwise, unset backends prefer flashinfer when it is available.
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"
        if self.attention_backend is None:
            self.attention_backend = (
                "flashinfer" if is_flashinfer_available() else "triton"
            )
        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )
        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True

        # Expert parallelism: EP MoE ties the expert-parallel size to the
        # tensor-parallel size.
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.info(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # DP attention: the data-parallel size follows the tensor-parallel
        # size, the prefill chunk is halved, and scheduling conservativeness
        # is scaled down to 30%.
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            # Trivially true right after the assignment above (for tp_size != 0).
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size //= 2
            self.schedule_conservativeness *= 0.3
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
            )

        # EAGLE speculative decoding makes the following changes:
        # - turns off the radix cache, chunked prefill (-1), the overlap
        #   scheduler and CUDA-graph padding;
        # - restricts prefill to one request per batch.
        if self.speculative_algorithm == "EAGLE":
            self.prefill_only_one_req = True
            self.disable_cuda_graph_padding = True
            self.disable_radix_cache = True
            self.disable_overlap_schedule = True
            self.chunked_prefill_size = -1
            logger.info(
                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
            )

        # GGUF checkpoints: when the load format is "auto" or "gguf" and the
        # model path is detected as GGUF, use GGUF for both loading and
        # quantization.
        if self.load_format in ("auto", "gguf") and check_gguf_file(self.model_path):
            self.quantization = "gguf"
            self.load_format = "gguf"

Lianmin Zheng's avatar
Lianmin Zheng committed
276
277
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
278
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
279
280
281
282
283
284
285
286
287
288
289
290
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
291
292
293
294
295
296
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
297
298
299
300
301
302
303
304
305
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
306
307
308
309
310
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
311
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
312
313
314
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
315
316
317
318
319
320
321
322
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
                "gguf",
                "bitsandbytes",
323
                "layered",
324
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
325
326
327
328
329
330
331
332
333
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
334
            "which is mainly for profiling."
335
336
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
337
338
339
340
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
341
        )
342
343
344
345
346
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
347
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
348
            "--dtype",
Cody Yu's avatar
Cody Yu committed
349
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
350
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
351
352
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
353
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
354
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
355
356
357
358
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
359
360
            '* "float32" for FP32 precision.',
        )
361
362
363
364
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
bjmsong's avatar
bjmsong committed
365
366
367
368
369
370
371
372
373
374
375
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
376
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
377
378
379
380
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
381
382
383
384
385
386
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
387
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
388
                "bitsandbytes",
389
                "gguf",
390
                "modelopt",
391
                "w8a8_int8",
Ying Sheng's avatar
Ying Sheng committed
392
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
393
394
            help="The quantization method.",
        )
395
396
397
398
399
400
401
402
403
404
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
405
            choices=["cuda", "xpu", "hpu", "cpu"],
406
407
            help="The device type.",
        )
408
409
410
411
412
413
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
414
415
416
417
418
419
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
420
421
422
423
424
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
425
426
427
428
429
430
431
432
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
433
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
434
435
436
437
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
438
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
439
        )
440
441
442
443
444
445
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
446
447
448
449
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
450
451
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
452
        )
453
454
455
456
457
458
459
460
461
462
463
464
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
465
        parser.add_argument(
466
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
467
            type=str,
468
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
469
            choices=["lpm", "random", "fcfs", "dfs-weight"],
470
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
471
        )
472
473
474
475
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
476
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
477
        )
478
479
480
481
482
483
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading",
        )
484
485
486
487
488
489
        parser.add_argument(
            "--prefill-only-one-req",
            type=bool,
            help="If true, we only prefill one request at one prefill batch",
            default=ServerArgs.prefill_only_one_req,
        )
490

491
        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
492
        parser.add_argument(
493
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
494
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
495
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
496
            default=ServerArgs.tp_size,
497
            help="The tensor parallelism size.",
498
        )
499
500
501
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
502
            default=ServerArgs.stream_interval,
503
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
504
        )
505
506
507
508
509
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
510
511
512
513
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
514
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
515
        )
516
517
518
519
520
521
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
522
523
524
525
526
527
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
528
529
530
531
532
533
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory.",
        )
534
535
536
537
538
539
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
540
541

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
542
543
544
545
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
546
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
547
        )
548
        parser.add_argument(
549
550
551
552
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
553
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
554
        parser.add_argument(
555
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
556
            action="store_true",
557
            help="Log the inputs and outputs of all requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
558
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
559
560
561
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
562
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
563
        )
564
565
566
567
568
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
569
570
571
572
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
573
            help="The log interval of decode batch.",
574
        )
575

576
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
577
578
579
580
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
581
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
582
        )
583
584
585
586
587
588
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
589
590
591
592
593
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
594

595
596
        # Data parallelism
        parser.add_argument(
597
            "--data-parallel-size",
598
599
600
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
601
            help="The data parallelism size.",
602
603
604
605
606
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
607
            help="The load balancing strategy for data parallelism.",
608
609
610
611
612
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
613

xiaobochen's avatar
xiaobochen committed
614
615
616
617
618
619
620
621
        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
622

623
        # Multi-node distributed serving
624
        parser.add_argument(
625
626
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
627
            type=str,
628
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
629
630
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
631
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
632
        )
633
634
635
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
636

Lianmin Zheng's avatar
Lianmin Zheng committed
637
638
639
640
641
642
643
644
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

        # Kernel backend
662
663
664
        parser.add_argument(
            "--attention-backend",
            type=str,
665
            choices=["flashinfer", "triton", "torch_native"],
666
667
668
669
670
671
672
673
674
675
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
676
677
678
679
680
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
681
            help="Choose the backend for grammar-guided decoding.",
682
        )
683

684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
            choices=["EAGLE"],
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
            help="The number of token sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_draft_tokens,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
            help="The number of token sampled from draft model in eagle2 each step.",
            choices=[1, 2, 4, 8],
            default=ServerArgs.speculative_eagle_topk,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

753
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
754
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
755
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
756
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
757
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
758
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
759
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
760
            "--disable-jump-forward",
Liangsheng Yin's avatar
Liangsheng Yin committed
761
            action="store_true",
Lianmin Zheng's avatar
Lianmin Zheng committed
762
            help="Disable jump-forward for grammar-guided decoding.",
Liangsheng Yin's avatar
Liangsheng Yin committed
763
        )
764
765
766
767
768
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
769
        parser.add_argument(
770
771
772
773
774
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
775
            "--disable-outlines-disk-cache",
776
            action="store_true",
777
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
778
        )
779
780
781
782
783
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
Ke Bao's avatar
Ke Bao committed
784
785
786
787
788
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
789
        parser.add_argument(
790
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
791
            action="store_true",
792
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
793
        )
794
795
796
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
797
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
798
        )
Ke Bao's avatar
Ke Bao committed
799
800
801
802
803
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
xiaobochen's avatar
xiaobochen committed
804
805
806
807
808
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
809
810
811
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
812
813
            help="Optimize the model with torch.compile. Experimental feature.",
        )
814
        parser.add_argument(
815
            "--torch-compile-max-bs",
816
            type=int,
817
            default=ServerArgs.torch_compile_max_bs,
818
819
            help="Set the maximum batch size when using torch compile.",
        )
820
        parser.add_argument(
821
            "--cuda-graph-max-bs",
822
            type=int,
823
            default=ServerArgs.cuda_graph_max_bs,
824
825
            help="Set the maximum batch size for cuda graph.",
        )
826
827
828
829
830
831
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
832
833
834
835
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
836
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
837
        )
838
839
840
841
842
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
843
        parser.add_argument(
844
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
845
            action="store_true",
846
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
847
        )
848
        parser.add_argument(
849
            "--triton-attention-reduce-in-fp32",
850
            action="store_true",
851
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
852
            "This only affects Triton attention kernels.",
853
        )
854
855
856
857
858
859
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
860
861
862
863
864
865
866
867
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
868
869
870
871
872
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
873
874
875
876
877
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
878
879
880
881
882
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
883
884
885
886
887
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
YAMY's avatar
YAMY committed
888
889
890
891
892
893
894
895
        # Function Calling
        parser.add_argument(
            "--tool-call-parser",
            type=str,
            choices=["qwen25", "mistral", "llama3"],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
        )
896
897
898
899
900
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
901

Lianmin Zheng's avatar
Lianmin Zheng committed
902
903
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a parsed argparse namespace.

        The long CLI names ``tensor_parallel_size``, ``data_parallel_size`` and
        ``expert_parallel_size`` are first copied onto the short field names
        ``tp_size``, ``dp_size`` and ``ep_size`` directly on ``args`` (so the
        namespace passed in is mutated). Every dataclass field of ``cls`` is
        then read from ``args`` by name and passed to the constructor.
        """
        aliases = (
            ("tp_size", "tensor_parallel_size"),
            ("dp_size", "data_parallel_size"),
            ("ep_size", "expert_parallel_size"),
        )
        for field_name, cli_name in aliases:
            setattr(args, field_name, getattr(args, cli_name))

        kwargs = {}
        for field in dataclasses.fields(cls):
            kwargs[field.name] = getattr(args, field.name)
        return cls(**kwargs)

    def url(self):
        """Return the server's base HTTP URL, wrapping IPv6 hosts in brackets."""
        host = f"[{self.host}]" if is_valid_ipv6_address(self.host) else self.host
        return f"http://{host}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
915

916
917
918
919
920
    def check_server_args(self):
        """Validate cross-field constraints and normalize ``lora_paths``.

        Raises:
            AssertionError: if ``tp_size`` is not divisible by ``nnodes``; if
                multi-node data parallelism is requested without DP attention;
                if ``max_loras_per_batch`` is not positive; if LoRA adapters are
                configured while cuda graph or radix cache is still enabled; or
                if ``base_gpu_id`` is negative.

        Side effect:
            When ``lora_paths`` is a list, it is replaced by a dict mapping
            adapter name to path. ``"name=path"`` entries split on the first
            ``=``; a bare entry maps to itself.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"
        # Checked separately so a bad value is not reported as a LoRA
        # compatibility problem.
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
        # FIXME: LoRA serving does not yet work together with cuda graph or
        # radix cache, so both must be disabled when adapters are configured.
        assert self.lora_paths is None or (
            self.disable_cuda_graph and self.disable_radix_cache
        ), "compatibility of lora and cuda graph and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"

        if isinstance(self.lora_paths, list):
            normalized = {}
            for entry in self.lora_paths:
                name, sep, path = entry.partition("=")
                normalized[name] = path if sep else entry
            self.lora_paths = normalized

Lianmin Zheng's avatar
Lianmin Zheng committed
941

Lianmin Zheng's avatar
Lianmin Zheng committed
942
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments, built with `ServerArgs.from_cli_args`.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    return ServerArgs.from_cli_args(raw_args)


960
961
962
# Offset added to ``server_args.port`` to pick the default ZMQ TCP address when
# DP attention runs on a single node without ``--dist-init-addr``.
ZMQ_TCP_PORT_DELTA: int = 233


Lianmin Zheng's avatar
Lianmin Zheng committed
963
964
@dataclasses.dataclass
class PortArgs:
    """ZMQ endpoints and the NCCL port used to wire the server's processes together."""

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def _new_ipc_name() -> str:
        """Return an ``ipc://`` endpoint whose path is a fresh temporary file.

        The file is created with ``delete=False`` so the path stays on disk, but
        the handle is closed right away instead of being left open until
        garbage collection.
        """
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            return f"ipc://{tmp.name}"

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate the NCCL port and ZMQ endpoints for a server instance.

        The NCCL port search starts at ``server_args.port`` plus a random
        offset in [100, 1000]. It steps by +42 below 60000 and by -43 otherwise
        until ``is_port_available`` accepts a port. There is no other exit from
        that loop, so it keeps searching until a free port appears.

        Without DP attention, each ZMQ endpoint is an ``ipc://`` temp-file path.
        With DP attention, TCP endpoints are derived from ``dist_init_addr``
        (``"host:port"``), or from ``127.0.0.1:port + ZMQ_TCP_PORT_DELTA`` on a
        single node when no address is given. With ``base = dist_init_port + 1``:
        the tokenizer uses ``base`` and the detokenizer uses ``base + 1``. The
        scheduler input uses ``base + 2`` when ``dp_rank`` is None, otherwise
        ``base + 3 + dp_rank``.

        Raises:
            AssertionError: if DP attention needs ``dist_init_addr`` and it is
                missing or is not of the form ``host:port``.
        """
        port = server_args.port + random.randint(100, 1000)
        while True:
            if is_port_available(port):
                break
            if port < 60000:
                port += 42
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=PortArgs._new_ipc_name(),
                scheduler_input_ipc_name=PortArgs._new_ipc_name(),
                detokenizer_ipc_name=PortArgs._new_ipc_name(),
                nccl_port=port,
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
        else:
            # Fail with an actionable message instead of an AttributeError on None.
            assert (
                server_args.dist_init_addr is not None
            ), "please provide --dist-init-addr as host:port of head node"
            dist_init_addr = server_args.dist_init_addr.split(":")
        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"

        dist_init_host, dist_init_port = dist_init_addr
        port_base = int(dist_init_port) + 1
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 2
        else:
            scheduler_input_port = port_base + 2 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            nccl_port=port,
        )
1019

1020
1021
1022
1023
1024
1025
1026
1027
1028
1029

class LoRAPathAction(argparse.Action):
    """Argparse action that collects LoRA adapter specs into a name -> path dict.

    Each value is either ``name=path`` (split on the first ``=``) or a bare
    path, which is used as both the name and the path.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        for entry in values:
            name, sep, path = entry.partition("=")
            adapters[name] = path if sep else entry
        setattr(namespace, self.dest, adapters)
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039


class DeprecatedAction(argparse.Action):
    """Argparse action for a removed flag: using it raises its help text.

    By default the flag consumes no values (``nargs=0``). When it appears on
    the command line, a ``ValueError`` carrying the option's ``help`` string
    is raised.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        message = self.help
        raise ValueError(message)