"docs/vscode:/vscode.git/clone" did not exist on "95c231e50d97406b0dd1974632415a99fae2e701"
server_args.py 38.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import logging
19
import random
20
21
import tempfile
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
22

23
24
import torch

25
from sglang.srt.hf_transformers_utils import check_gguf_file
26
from sglang.srt.utils import (
HAI's avatar
HAI committed
27
    get_amdgpu_memory_capacity,
28
    get_hpu_memory_capacity,
HAI's avatar
HAI committed
29
    get_nvgpu_memory_capacity,
30
    is_flashinfer_available,
HAI's avatar
HAI committed
31
    is_hip,
32
    is_port_available,
33
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
34
    nullable_str,
35
)
36

37
38
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
39
40
41

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
42
    # Model and tokenizer
    model_path: str
    # Falls back to model_path in __post_init__ when left as None.
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    # "auto"/"gguf" is switched to "gguf" in __post_init__ for GGUF checkpoints.
    load_format: str = "auto"
    # NOTE(review): the --trust-remote-code CLI flag is store_true (False unless
    # passed), so this True default only applies when ServerArgs is built
    # directly in Python — confirm the mismatch is intended.
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization_param_path: nullable_str = None
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    # Falls back to model_path in __post_init__ when left as None.
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False
    revision: Optional[str] = None
    skip_tokenizer_init: bool = False

    # Port for the HTTP server
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    # Derived from tp_size in __post_init__ when left as None.
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # Derived from GPU memory in __post_init__ when None; -1 disables chunked prefill.
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    prefill_only_one_req: bool = False

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    # A random seed is drawn in __post_init__ when left as None.
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    download_dir: Optional[str] = None
    base_gpu_id: int = 0

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_pth: str = "sglang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Expert parallelism
    ep_size: int = 1

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    # attention/sampling backends are chosen in __post_init__ when left as None.
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Speculative decoding
    speculative_draft_model_path: Optional[str] = None
    speculative_algorithm: Optional[str] = None
    speculative_num_steps: int = 5
    speculative_num_draft_tokens: int = 64
    speculative_eagle_topk: int = 8

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    # Derived from GPU memory and tp_size in __post_init__ when left as None.
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False

    # Custom logit processor
    enable_custom_logit_processor: bool = False

    # Tool/function calling
    tool_call_parser: Optional[str] = None
165

Lianmin Zheng's avatar
Lianmin Zheng committed
166
    def __post_init__(self):
        """Resolve unset options and reconcile mutually dependent settings.

        Called automatically by the dataclass-generated ``__init__``. Fields
        left as ``None`` are filled in from ``model_path``, the detected
        accelerator memory and ``tp_size``. Enabled features (torch-native
        attention, EP MoE, DP attention, EAGLE speculative decoding, GGUF
        weights) then override the options they conflict with. All changes
        are made in place on ``self``.
        """
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Probe the accelerator's memory capacity. The 25_000 threshold below is
        # described as "HBM<25G", so the unit is presumably MB — confirm against
        # the get_*_memory_capacity helpers.
        if is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        elif torch.cuda.is_available():
            gpu_mem = get_nvgpu_memory_capacity()
        elif self.device == "hpu":
            gpu_mem = get_hpu_memory_capacity()
        else:
            # GPU memory is not known yet or no GPU is available.
            gpu_mem = None
        # Lower-end accelerators get smaller prefill chunks and CUDA-graph sizes.
        is_small_gpu = gpu_mem is not None and gpu_mem < 25_000

        # Set mem fraction static, which depends on the tensor parallelism
        # size: larger TP sizes use a smaller static fraction.
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.81
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Set chunked prefill size, which depends on the GPU memory capacity.
        if self.chunked_prefill_size is None:
            self.chunked_prefill_size = 2048 if is_small_gpu else 8192

        # Set CUDA graph max batch size.
        if self.cuda_graph_max_bs is None:
            # On lower-end GPUs (HBM < 25G), TP1/TP2 deployments can disable
            # CUDA graphs or use a very small `cuda_graph_max_bs` to reduce the
            # memory overhead of capturing graphs, with almost no performance
            # impact. TP4/TP8 deployments need CUDA graphs to stay fast, so they
            # use 80 (half of the default 160): logs from TP4 serving of
            # qwen2-72b show 80 is sufficient, and it avoids OOM on these GPUs.
            if is_small_gpu:
                self.cuda_graph_max_bs = 8 if self.tp_size < 4 else 80
            else:
                self.cuda_graph_max_bs = 160

        # Choose kernel backends. HPU always uses the torch-native attention
        # and pytorch sampling backends, overriding any user choice.
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        if self.attention_backend is None:
            self.attention_backend = (
                "flashinfer" if is_flashinfer_available() else "triton"
            )
        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            # Only warn when this actually changes the setting; if
            # disable_cuda_graph is already True there is nothing to report.
            if not self.disable_cuda_graph:
                logger.warning(
                    "Cuda graph is disabled because of using torch native attention backend"
                )
            self.disable_cuda_graph = True

        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.info(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Others
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            # Trivially satisfied right after the assignment above; it can only
            # fail (with ZeroDivisionError) when tp_size == 0.
            assert self.tp_size % self.dp_size == 0
            # Floor division keeps a disabled value (-1) at -1.
            self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
            )

        # Speculative Decoding
        if self.speculative_algorithm == "EAGLE":
            self.prefill_only_one_req = True
            self.disable_cuda_graph_padding = True
            self.disable_radix_cache = True
            self.disable_overlap_schedule = True
            # -1 disables chunked prefill.
            self.chunked_prefill_size = -1
            logger.info(
                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
            )

        # GGUF: the file check is only done when the load format allows GGUF.
        if self.load_format in ("auto", "gguf") and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

Lianmin Zheng's avatar
Lianmin Zheng committed
274
275
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
276
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
277
278
279
280
281
282
283
284
285
286
287
288
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
289
290
291
292
293
294
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
295
296
297
298
299
300
301
302
303
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
304
305
306
307
308
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
309
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
310
311
312
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
313
314
315
316
317
318
319
320
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
                "gguf",
                "bitsandbytes",
321
                "layered",
322
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
323
324
325
326
327
328
329
330
331
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
332
            "which is mainly for profiling."
333
334
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
335
336
337
338
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
339
        )
340
341
342
343
344
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
345
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
346
            "--dtype",
Cody Yu's avatar
Cody Yu committed
347
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
348
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
349
350
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
351
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
352
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
353
354
355
356
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
357
358
            '* "float32" for FP32 precision.',
        )
359
360
361
362
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
bjmsong's avatar
bjmsong committed
363
364
365
366
367
368
369
370
371
372
373
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
374
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
375
376
377
378
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
379
380
381
382
383
384
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
385
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
386
                "bitsandbytes",
387
                "gguf",
388
                "modelopt",
389
                "w8a8_int8",
Ying Sheng's avatar
Ying Sheng committed
390
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
391
392
            help="The quantization method.",
        )
393
394
395
396
397
398
399
400
401
402
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
403
            choices=["cuda", "xpu", "hpu", "cpu"],
404
405
            help="The device type.",
        )
406
407
408
409
410
411
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
412
413
414
415
416
417
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
418
419
420
421
422
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
423
424
425
426
427
428
429
430
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
431
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
432
433
434
435
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
436
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
437
        )
438
439
440
441
442
443
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
444
445
446
447
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
448
449
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
450
        )
451
452
453
454
455
456
457
458
459
460
461
462
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
463
        parser.add_argument(
464
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
465
            type=str,
466
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
467
            choices=["lpm", "random", "fcfs", "dfs-weight"],
468
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
469
        )
470
471
472
473
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
474
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
475
        )
476
477
478
479
480
481
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading",
        )
482
483
484
485
486
487
        parser.add_argument(
            "--prefill-only-one-req",
            type=bool,
            help="If true, we only prefill one request at one prefill batch",
            default=ServerArgs.prefill_only_one_req,
        )
488

489
        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
490
        parser.add_argument(
491
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
492
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
493
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
494
            default=ServerArgs.tp_size,
495
            help="The tensor parallelism size.",
496
        )
497
498
499
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
500
            default=ServerArgs.stream_interval,
501
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
502
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
503
504
505
506
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
507
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
508
        )
509
510
511
512
513
514
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
515
516
517
518
519
520
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
521
522
523
524
525
526
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory.",
        )
527
528
529
530
531
532
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
533
534

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
535
536
537
538
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
539
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
540
        )
541
        parser.add_argument(
542
543
544
545
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
546
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
547
        parser.add_argument(
548
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
549
            action="store_true",
550
            help="Log the inputs and outputs of all requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
551
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
552
553
554
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
555
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
556
        )
557
558
559
560
561
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
562
563
564
565
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
566
            help="The log interval of decode batch.",
567
        )
568

569
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
570
571
572
573
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
574
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
575
        )
576
577
578
579
580
581
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
582
583
584
585
586
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
587

588
589
        # Data parallelism
        parser.add_argument(
590
            "--data-parallel-size",
591
592
593
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
594
            help="The data parallelism size.",
595
596
597
598
599
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
600
            help="The load balancing strategy for data parallelism.",
601
602
603
604
605
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
606

xiaobochen's avatar
xiaobochen committed
607
608
609
610
611
612
613
614
        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
615

616
        # Multi-node distributed serving
617
        parser.add_argument(
618
619
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
620
            type=str,
621
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
622
623
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
624
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
625
        )
626
627
628
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
629

Lianmin Zheng's avatar
Lianmin Zheng committed
630
631
632
633
634
635
636
637
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

        # Kernel backend
655
656
657
        parser.add_argument(
            "--attention-backend",
            type=str,
658
            choices=["flashinfer", "triton", "torch_native"],
659
660
661
662
663
664
665
666
667
668
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
669
670
671
672
673
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
674
            help="Choose the backend for grammar-guided decoding.",
675
        )
676

677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
            choices=["EAGLE"],
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
            help="The number of token sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_draft_tokens,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
            help="The number of token sampled from draft model in eagle2 each step.",
            choices=[1, 2, 4, 8],
            default=ServerArgs.speculative_eagle_topk,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

746
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
747
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
748
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
749
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
750
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
751
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
752
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
753
            "--disable-jump-forward",
Liangsheng Yin's avatar
Liangsheng Yin committed
754
            action="store_true",
Lianmin Zheng's avatar
Lianmin Zheng committed
755
            help="Disable jump-forward for grammar-guided decoding.",
Liangsheng Yin's avatar
Liangsheng Yin committed
756
        )
757
758
759
760
761
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
762
        parser.add_argument(
763
764
765
766
767
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
768
            "--disable-outlines-disk-cache",
769
            action="store_true",
770
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
771
        )
772
773
774
775
776
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
Ke Bao's avatar
Ke Bao committed
777
778
779
780
781
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
782
        parser.add_argument(
783
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
784
            action="store_true",
785
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
786
        )
787
788
789
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
790
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
791
        )
Ke Bao's avatar
Ke Bao committed
792
793
794
795
796
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
xiaobochen's avatar
xiaobochen committed
797
798
799
800
801
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
802
803
804
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
805
806
            help="Optimize the model with torch.compile. Experimental feature.",
        )
807
        parser.add_argument(
808
            "--torch-compile-max-bs",
809
            type=int,
810
            default=ServerArgs.torch_compile_max_bs,
811
812
            help="Set the maximum batch size when using torch compile.",
        )
813
        parser.add_argument(
814
            "--cuda-graph-max-bs",
815
            type=int,
816
            default=ServerArgs.cuda_graph_max_bs,
817
818
            help="Set the maximum batch size for cuda graph.",
        )
819
820
821
822
823
824
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
825
826
827
828
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
829
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
830
        )
831
832
833
834
835
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
836
        parser.add_argument(
837
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
838
            action="store_true",
839
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
840
        )
841
        parser.add_argument(
842
            "--triton-attention-reduce-in-fp32",
843
            action="store_true",
844
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
845
            "This only affects Triton attention kernels.",
846
        )
847
848
849
850
851
852
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
853
854
855
856
857
858
859
860
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
861
862
863
864
865
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
866
867
868
869
870
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
871
872
873
874
875
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
876
877
878
879
880
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
YAMY's avatar
YAMY committed
881
882
883
884
885
886
887
888
        # Function Calling
        parser.add_argument(
            "--tool-call-parser",
            type=str,
            choices=["qwen25", "mistral", "llama3"],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
        )
889

Lianmin Zheng's avatar
Lianmin Zheng committed
890
891
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance from a parsed ``argparse.Namespace``.

        The CLI spells the parallelism sizes with long names
        (``--tensor-parallel-size`` etc.), while the dataclass fields use the
        short names (``tp_size`` etc.). The long-name values are first copied
        onto the short attributes of ``args`` (mutating ``args``). Then every
        dataclass field is read from ``args`` by name.
        """
        alias_pairs = (
            ("tp_size", "tensor_parallel_size"),
            ("dp_size", "data_parallel_size"),
            ("ep_size", "expert_parallel_size"),
        )
        for short_name, long_name in alias_pairs:
            setattr(args, short_name, getattr(args, long_name))

        field_values = {}
        for field in dataclasses.fields(cls):
            field_values[field.name] = getattr(args, field.name)
        return cls(**field_values)

    def url(self):
        """Return the ``http://host:port`` base URL of this server.

        IPv6 literals are wrapped in square brackets so the trailing
        ``:port`` separator stays unambiguous.
        """
        if not is_valid_ipv6_address(self.host):
            return f"http://{self.host}:{self.port}"
        return f"http://[{self.host}]:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
903

904
905
906
907
908
    def check_server_args(self):
        """Validate cross-argument constraints and normalize ``lora_paths``.

        Side effect: when ``lora_paths`` is a list (e.g. set programmatically
        rather than through ``LoRAPathAction``), it is replaced by a
        ``{name: path}`` dict. ``name=path`` entries are split on the first
        ``=``; a bare path is used as its own name.

        Raises:
            AssertionError: if any of the constraints below is violated.
        """
        # nnodes is used as a divisor below; reject 0/negative values with a
        # clear message instead of a ZeroDivisionError.
        assert self.nnodes >= 1, "nnodes must be a positive integer"
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"
        # Checked on its own so a bad value is not misreported as a LoRA
        # compatibility problem.
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
        # FIXME: LoRA currently requires both cuda graph and radix cache to be
        # disabled.
        assert self.lora_paths is None or (
            self.disable_cuda_graph and self.disable_radix_cache
        ), "compatibility of lora and cuda graph and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"

        if isinstance(self.lora_paths, list):
            normalized = {}
            for lora_path in self.lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    normalized[name] = path
                else:
                    normalized[lora_path] = lora_path
            self.lora_paths = normalized

Lianmin Zheng's avatar
Lianmin Zheng committed
929

Lianmin Zheng's avatar
Lianmin Zheng committed
930
def prepare_server_args(argv: List[str]) -> ServerArgs:
931
932
933
934
935
936
937
938
939
940
941
942
    """
    Prepare the server arguments from the command line arguments.

    Args:
        args: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
Lianmin Zheng's avatar
Lianmin Zheng committed
943
    raw_args = parser.parse_args(argv)
944
945
946
947
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


948
949
950
# Offset added to the server's HTTP port to derive the default ZMQ TCP base
# port when DP attention runs on a single node without --dist-init-addr
# (used by PortArgs.init_new).
ZMQ_TCP_PORT_DELTA: int = 233


Lianmin Zheng's avatar
Lianmin Zheng committed
951
952
@dataclasses.dataclass
class PortArgs:
    """ZMQ endpoints between the server processes plus the NCCL init port.

    The three ``*_ipc_name`` fields are ZMQ endpoint strings. They are
    ``ipc://<tempfile>`` in the default single-node setup and
    ``tcp://<host>:<port>`` when DP attention is enabled.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def _split_dist_init_addr(addr: str):
        """Split a ``host:port`` string into ``(host, port)`` strings.

        The split is on the *last* colon, so a bracketed IPv6 head address
        such as ``[::1]:25000`` stays intact. A host part that still contains
        a colon without brackets is rejected, e.g. an unbracketed IPv6
        literal or ``a:b:c``. Missing host or port parts are rejected too.

        Raises:
            AssertionError: if ``addr`` is not of the form ``host:port``.
        """
        host, sep, port = addr.rpartition(":")
        is_bracketed = host.startswith("[") and host.endswith("]")
        assert (
            sep and host and port and (":" not in host or is_bracketed)
        ), "please provide --dist-init-addr as host:port of head node"
        return host, port

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate a new set of endpoints for ``server_args``.

        Args:
            server_args: Server configuration. Reads ``port``,
                ``enable_dp_attention``, ``nnodes`` and ``dist_init_addr``.
            dp_rank: Only used with DP attention. ``None`` selects the shared
                scheduler-input endpoint (tokenizer -> data parallel
                controller); an int selects a per-rank endpoint.

        Returns:
            A new ``PortArgs``.

        Raises:
            AssertionError: if DP attention needs ``--dist-init-addr`` and
                it is missing or malformed.
        """
        # Start from a random offset above the server port. Probe upwards in
        # steps of 42, or downwards in steps of 43 once past 60000.
        port = server_args.port + random.randint(100, 1000)
        while not is_port_available(port):
            if port < 60000:
                port += 42
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                nccl_port=port,
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            dist_init_host = "127.0.0.1"
            dist_init_port = server_args.port + ZMQ_TCP_PORT_DELTA
        else:
            # Multi-node runs must name the head node explicitly; fail with a
            # clear message rather than an AttributeError on None.
            assert (
                server_args.dist_init_addr is not None
            ), "please provide --dist-init-addr as host:port of head node"
            dist_init_host, dist_init_port = PortArgs._split_dist_init_addr(
                server_args.dist_init_addr
            )

        # Port layout relative to port_base:
        #   +0 tokenizer, +1 detokenizer, +2 shared scheduler input,
        #   +3 + dp_rank per-rank scheduler input.
        port_base = int(dist_init_port) + 1
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 2
        else:
            scheduler_input_port = port_base + 2 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            nccl_port=port,
        )
1007

1008
1009
1010
1011
1012
1013
1014
1015
1016
1017

class LoRAPathAction(argparse.Action):
    """argparse action that turns ``--lora-paths`` values into a dict.

    Each value is either ``name=path`` (split on the first ``=``) or a bare
    path, which then doubles as its own name. The destination attribute is
    always reset to a fresh dict, so passing the flag with no values yields
    ``{}``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        setattr(namespace, self.dest, adapters)
        for entry in values:
            if "=" not in entry:
                adapters[entry] = entry
                continue
            adapter_name, adapter_path = entry.split("=", 1)
            adapters[adapter_name] = adapter_path
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027


class DeprecatedAction(argparse.Action):
    """argparse action for retired flags.

    The flag takes no values by default. Supplying it on the command line
    raises ``ValueError`` whose message is the flag's ``help`` text, so the
    help string should explain what to use instead.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)