# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The arguments of the server."""

import argparse
import dataclasses
import logging
import random
import tempfile
from typing import List, Optional

import torch

from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
    get_amdgpu_memory_capacity,
    get_hpu_memory_capacity,
    get_nvgpu_memory_capacity,
    is_flashinfer_available,
    is_hip,
    is_ipv6,
    is_port_available,
)

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ServerArgs:
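    """Configuration for the SGLang server: model/tokenizer settings, memory and
    scheduling limits, parallelism sizes, logging, API, and optimization flags.
    Defaults that depend on the hardware are filled in by `__post_init__`."""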
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    load_format: str = "auto"
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False
    revision: Optional[str] = None
    skip_tokenizer_init: bool = False
    return_token_ids: bool = False

    # Port for the HTTP server
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    prefill_only_one_req: bool = False

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    download_dir: Optional[str] = None
    base_gpu_id: int = 0

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Expert parallelism
    ep_size: int = 1

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Speculative decoding
    speculative_draft_model_path: Optional[str] = None
    speculative_algorithm: Optional[str] = None
    speculative_num_steps: int = 5
    speculative_num_draft_tokens: int = 64
    speculative_eagle_topk: int = 8

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False

    def __post_init__(self):
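        """Fill in defaults that were left unset, based on the detected hardware
        (GPU memory, device type) and the parallelism configuration."""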
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        if is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        elif torch.cuda.is_available():
            gpu_mem = get_nvgpu_memory_capacity()
        elif self.device == "hpu":
            gpu_mem = get_hpu_memory_capacity()
        else:
            # GPU memory is not known yet or no GPU is available.
            gpu_mem = None

        # Set mem fraction static, which depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.81
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None and gpu_mem < 25_000:
                self.chunked_prefill_size = 2048
            else:
                self.chunked_prefill_size = 8192

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on lower-end
            # GPUs with HBM < 25 GB, you can either disable cuda graph or set
            # `cuda_graph_max_bs` to a very small value to reduce the memory overhead
            # of creating cuda graphs, with almost no impact on performance. However,
            # when serving models with TP4 or TP8, we need to enable cuda graph to
            # maintain high performance. In this case, we can set `cuda_graph_max_bs`
            # to 80 (half of the default value 160) to reduce the memory overhead of
            # creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b,
            # a value of 80 is sufficient and can reduce the memory overhead of
            # creating cuda graphs on lower-end GPUs compared to the original 160,
            # avoiding OOM issues.
            if gpu_mem is not None and gpu_mem < 25_000:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80
            else:
                self.cuda_graph_max_bs = 160

        # Choose kernel backends
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        if self.attention_backend is None:
            self.attention_backend = (
                "flashinfer" if is_flashinfer_available() else "triton"
            )
        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "CUDA graph is disabled because the torch_native attention backend is used."
            )
            self.disable_cuda_graph = True

        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.info(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size [{self.tp_size}]."
            )

        # Others
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            self.disable_overlap_schedule = True
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
                "Overlap scheduler is disabled."
            )

        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
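        """Register all server arguments on the given argparse parser. Flag defaults
        are taken from the dataclass defaults (e.g. `ServerArgs.tp_size`)."""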
        # Model and port args
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
                "gguf",
                "bitsandbytes",
            ],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling. "
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
            "quantization.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
                "gguf",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
            choices=["cuda", "xpu", "hpu"],
            help="The device type.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The built-in chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip the tokenizer initialization and pass input_ids directly in generate requests.",
        )
        parser.add_argument(
            "--return-token-ids",
            action="store_true",
            default=ServerArgs.return_token_ids,
            help="Whether to return token IDs in the output. This may introduce additional overhead.",
        )

        # Memory and scheduling
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 disables chunked prefill.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading.",
        )
        parser.add_argument(
            "--prefill-only-one-req",
            type=bool,
            help="If true, we only prefill one request in each prefill batch.",
            default=ServerArgs.prefill_only_one_req,
        )

        # Other runtime options
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory.",
        )
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )

        # Logging
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable Prometheus metrics logging.",
        )
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch.",
        )

        # API related
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return the number of cached tokens in usage.prompt_tokens_details for each OpenAI request.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )

        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either paths (str) or renamed paths in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, including base-only requests.",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton", "torch_native"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for grammar-guided decoding.",
        )

        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
            choices=["EAGLE"],
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
            help="The number of token sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_draft_tokens,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
            help="The number of token sampled from draft model in eagle2 each step.",
            choices=[1, 2, 4, 8],
            default=ServerArgs.speculative_eagle_topk,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-jump-forward",
            action="store_true",
            help="Disable jump-forward for grammar-guided decoding.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-outlines-disk-cache",
            action="store_true",
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
        )
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access; otherwise, P2P access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
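        # The parser registers the long flags (e.g. --tensor-parallel-size) as the
        # canonical dest names, so copy them back into the short dataclass field
        # names before constructing the ServerArgs instance.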
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        args.ep_size = args.expert_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
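        """Validate cross-argument constraints and normalize `lora_paths` into a dict."""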
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1
        ), "multi-node data parallel is not supported"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"

        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path


def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


@dataclasses.dataclass
class PortArgs:
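    """The ZMQ IPC endpoints and the NCCL port used by the tokenizer, scheduler,
    and detokenizer processes to communicate with each other."""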
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def init_new(server_args) -> "PortArgs":
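        # Probe for a free port for NCCL/torch.distributed initialization, starting
        # from a random offset above the HTTP port and stepping until an available
        # port is found.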
        port = server_args.port + random.randint(100, 1000)
        while True:
            if is_port_available(port):
                break
            port += 42

        return PortArgs(
            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            nccl_port=port,
        )


class LoRAPathAction(argparse.Action):
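    """Argparse action that parses `--lora-paths` entries of the form `{name}={path}`
    (or a bare path, which is used as both the name and the path) into a dict."""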
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path


class DeprecatedAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super(DeprecatedAction, self).__init__(
            option_strings, dest, nargs=nargs, **kwargs
        )

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)
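

# Minimal usage sketch (illustrative only): build ServerArgs from CLI-style argv the
# same way the launcher calls `prepare_server_args(sys.argv[1:])`. The model path
# below is a placeholder.
if __name__ == "__main__":
    example_args = prepare_server_args(
        ["--model-path", "my-org/my-model", "--port", "30000"]
    )
    print(example_args.url())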