server_args.py 29.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import logging
21
import random
22
23
import tempfile
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.utils import (
HAI's avatar
HAI committed
26
27
    get_amdgpu_memory_capacity,
    get_nvgpu_memory_capacity,
28
    is_flashinfer_available,
HAI's avatar
HAI committed
29
    is_hip,
30
31
32
    is_ipv6,
    is_port_available,
)
33

34
35
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
36
37
38

@dataclasses.dataclass
class ServerArgs:
    """All configuration knobs of an SGLang server.

    Instances are usually built from the command line via ``add_cli_args`` and
    ``from_cli_args``. ``__post_init__`` fills in derived defaults (tokenizer
    path, memory fraction, kernel backends, ...) and may adjust user-provided
    values for small GPUs or DP attention.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # A value <= 0 disables chunked prefill; __post_init__ then stores None.
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    download_dir: Optional[str] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    # Either a list of "path" / "name=path" strings, or a name -> path dict
    # (check_server_args converts the list form into the dict form).
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: int = 160
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False

    @staticmethod
    def _shrink_chunked_prefill_size(
        size: Optional[int], factor: int
    ) -> Optional[int]:
        """Floor-divide a chunked prefill size by ``factor``.

        ``None`` means chunked prefill is disabled and is returned unchanged,
        so callers never attempt ``None // factor``.
        """
        if size is None:
            return None
        return size // factor

    def __post_init__(self):
        """Fill in derived defaults and apply hardware/feature adjustments."""
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.82
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Adjust for GPUs with small memory capacities
        # NOTE(review): the 25000 threshold presumably is in MiB — confirm
        # against get_{amd,nv}gpu_memory_capacity.
        if is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        else:
            gpu_mem = get_nvgpu_memory_capacity()
        if gpu_mem < 25000:
            # e.g. the default 8192 becomes 2048; stays None if disabled.
            self.chunked_prefill_size = self._shrink_chunked_prefill_size(
                self.chunked_prefill_size, 4
            )
            self.cuda_graph_max_bs = 4
            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

        # Choose kernel backends
        if not is_flashinfer_available():
            # Without flashinfer, force the fallback kernels (overrides any choice).
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        if self.attention_backend is None:
            self.attention_backend = "flashinfer"
        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        # Others
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            self.chunked_prefill_size = self._shrink_chunked_prefill_size(
                self.chunked_prefill_size, 2
            )
            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            self.disable_overlap_schedule = True
            logger.info(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
                "Overlap schedule is disabled."
            )

        if self.enable_mixed_chunk:
            logger.info(
                "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
            )
            self.disable_overlap_schedule = True

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every ServerArgs option on ``parser``.

        Each dataclass field must have a matching argument ``dest`` because
        ``from_cli_args`` reads all fields from the parsed namespace.
        """
        # Model and port args
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        # NOTE(review): the dataclass default is True, but with store_true the
        # CLI default is False, so CLI-launched servers do not trust remote
        # code unless this flag is passed.
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            choices=["cuda", "xpu"],
            help="The device type.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The built-in chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )

        # Memory and scheduling
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )

        # Other runtime options
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory.",
        )

        # Logging
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch",
        )

        # API related
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            default=ServerArgs.dist_init_addr,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The sequence length threshold for sparse decoding in double sparsity attention",
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=ServerArgs.max_loras_per_batch,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for grammar-guided decoding.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-jump-forward",
            action="store_true",
            help="Disable jump-forward for grammar-guided decoding.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        # There is no `disable_nan_detection` field, so from_cli_args ignores
        # this flag; it is kept only so existing command lines keep parsing.
        parser.add_argument(
            "--disable-nan-detection",
            action="store_true",
            help="No-op kept for backward compatibility: NaN detection is already off unless --enable-nan-detection is set.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
        )
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )

        # Deprecated arguments
        parser.add_argument(
            "--enable-overlap-schedule",
            action=DeprecatedAction,
            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
        )
        parser.add_argument(
            "--disable-flashinfer",
            action=DeprecatedAction,
            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action=DeprecatedAction,
            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytorch' instead.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a namespace produced by ``add_cli_args``.

        argparse derives the dests ``tensor_parallel_size`` / ``data_parallel_size``
        from the first long option, so they are copied onto the field names
        (this mutates ``args``).
        """
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the HTTP base URL, bracketing the host if it is IPv6."""
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        """Validate cross-field constraints and normalize ``lora_paths``.

        Raises:
            AssertionError: if a constraint is violated.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1
        ), "multi-node data parallel is not supported"
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
        # FIXME: LoRA is not yet compatible with cuda graph or the radix cache.
        assert self.lora_paths is None or (
            self.disable_cuda_graph and self.disable_radix_cache
        ), "compatibility of lora and cuda graph and radix attention is in progress"

        # Convert the list form ("path" or "name=path") into a name -> path dict.
        if isinstance(self.lora_paths, list):
            normalized = {}
            for lora_path in self.lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                else:
                    name = path = lora_path
                normalized[name] = path
            self.lora_paths = normalized

Lianmin Zheng's avatar
Lianmin Zheng committed
750

Lianmin Zheng's avatar
Lianmin Zheng committed
751
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments, excluding the program name.
            Typically `sys.argv[1:]`. The list is passed explicitly to
            `parse_args`, so an empty list parses no arguments instead of
            falling back to reading `sys.argv`.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    return ServerArgs.from_cli_args(raw_args)


Lianmin Zheng's avatar
Lianmin Zheng committed
769
770
@dataclasses.dataclass
class PortArgs:
    """IPC endpoints and the NCCL port used by one server instance."""

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def _new_ipc_name() -> str:
        """Create a fresh temporary file and return its path."""
        # delete=False keeps the (empty) file on disk as before; the `with`
        # closes the descriptor explicitly instead of relying on the GC.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            return tmp.name

    @staticmethod
    def init_new(server_args) -> "PortArgs":
        """Pick a free NCCL port near ``server_args.port`` and new IPC paths.

        The search starts at ``server_args.port`` plus a random offset in
        [100, 1000] and probes candidates with ``is_port_available`` until one
        is free.
        """
        port = server_args.port + random.randint(100, 1000)
        # Step upward by 42; from 60000 on step downward by 43 instead, so the
        # search never walks past 65535 (the largest TCP port). Out-of-range
        # starting points are skipped without being probed.
        while port > 65535 or not is_port_available(port):
            if port < 60000:
                port += 42
            else:
                port -= 43

        return PortArgs(
            tokenizer_ipc_name=PortArgs._new_ipc_name(),
            scheduler_input_ipc_name=PortArgs._new_ipc_name(),
            detokenizer_ipc_name=PortArgs._new_ipc_name(),
            nccl_port=port,
        )

796
797
798
799
800
801
802
803
804
805

class LoRAPathAction(argparse.Action):
    """Collect LoRA adapter arguments into a ``{name: path}`` dict.

    Each value is either ``"name=path"`` (split on the first ``=``) or a bare
    ``"path"``, which is then used as its own name. Each invocation replaces
    the destination with a new dict.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        setattr(namespace, self.dest, adapters)
        for entry in values:
            name, sep, path = entry.partition("=")
            if sep:
                adapters[name] = path
            else:
                adapters[entry] = entry
806
807
808
809
810
811
812
813
814
815


class DeprecatedAction(argparse.Action):
    """Flag action for removed options: using the flag raises its help text.

    The option takes no values by default (``nargs=0``); when it appears on
    the command line, a ``ValueError`` carrying ``help`` is raised.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)