server_args.py 30.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import logging
19
import random
20
21
import tempfile
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
22

23
from sglang.srt.utils import (
HAI's avatar
HAI committed
24
25
    get_amdgpu_memory_capacity,
    get_nvgpu_memory_capacity,
26
    is_flashinfer_available,
HAI's avatar
HAI committed
27
    is_hip,
28
29
30
    is_ipv6,
    is_port_available,
)
31

32
33
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
34
35
36

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
37
    # Model and tokenizer
    model_path: str
    # When left as None, __post_init__ falls back to model_path.
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    # NOTE(review): --trust-remote-code is a store_true flag (CLI default
    # False), so the CLI default differs from this programmatic default.
    # Confirm this is intended.
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    # None: use the value from the model's config.json (per --context-length help).
    context_length: Optional[int] = None
    device: str = "cuda"
    # When left as None, __post_init__ falls back to model_path.
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    # When left as None, __post_init__ picks a value based on tp_size.
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # A value <= 0 disables chunked prefill; __post_init__ then stores None,
    # hence the Optional annotation.
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    # When left as None, __post_init__ draws a random seed.
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    # In seconds (per --watchdog-timeout help).
    watchdog_timeout: float = 300
    download_dir: Optional[str] = None
    base_gpu_id: int = 0

    # Logging
    log_level: str = "info"
    # None: reuse log_level (per --log-level-http help).
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    # Overwritten with tp_size by __post_init__ when enable_dp_attention is set.
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    # When left as None, __post_init__ selects "flashinfer"; both are forced
    # to triton/pytorch if flashinfer is unavailable.
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: int = 160
    torchao_config: str = ""
    # NOTE(review): add_cli_args also registers --disable-nan-detection, which
    # has no matching field here. Confirm which flag is meant to survive.
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
Lianmin Zheng committed
137
138

    def __post_init__(self):
        """Fill in derived defaults and reconcile interdependent options.

        Called automatically after the dataclass-generated ``__init__``. All
        adjustments mutate ``self`` in place:

        * tokenizer_path / served_model_name default to model_path;
        * chunked_prefill_size <= 0 disables chunked prefill (stored as None);
        * random_seed and mem_fraction_static get defaults when None;
        * small-memory GPUs get a smaller prefill chunk and CUDA-graph batch;
        * kernel backends are chosen based on flashinfer availability;
        * DP attention and mixed chunking disable the overlap scheduler.

        NOTE(review): several adjustments (``//= 4``, the DP-attention halving,
        ``schedule_conservativeness * 0.3``) are not idempotent. Re-running
        ``__post_init__`` (e.g. via ``dataclasses.replace``) compounds them.
        """
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        # None already means "disabled". Guard it so that re-constructing from
        # an existing instance's fields does not fail on `None <= 0`.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.82
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Adjust for GPUs with small memory capacities.
        # NOTE(review): the 25000 threshold presumably means MB; confirm
        # against get_*_memory_capacity.
        if is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        else:
            gpu_mem = get_nvgpu_memory_capacity()
        if gpu_mem < 25000:
            # Quarter the chunk (8192 -> 2048 with the default). Skip when
            # chunked prefill is disabled: `None //= 4` would raise TypeError.
            if self.chunked_prefill_size is not None:
                self.chunked_prefill_size //= 4
            # Cap rather than overwrite, so a smaller user-chosen value is kept
            # (same approach as the DP-attention branch below).
            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 4)
            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

        # Choose kernel backends
        if not is_flashinfer_available():
            # flashinfer is unavailable, so explicit requests for it are overridden.
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        if self.attention_backend is None:
            self.attention_backend = "flashinfer"
        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        # Others
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            # Halve the chunk only when chunked prefill is enabled;
            # `None // 2` would raise TypeError.
            if self.chunked_prefill_size is not None:
                self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            self.disable_overlap_schedule = True
            logger.info(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
                "Overlap schedule is disabled."
            )

        if self.enable_mixed_chunk:
            logger.info(
                "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
            )
            self.disable_overlap_schedule = True

Lianmin Zheng's avatar
Lianmin Zheng committed
207
208
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
209
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
210
211
212
213
214
215
216
217
218
219
220
221
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
222
223
224
225
226
227
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
228
229
230
231
232
233
234
235
236
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
237
238
239
240
241
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
258
259
260
261
262
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
263
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
264
            "--dtype",
Cody Yu's avatar
Cody Yu committed
265
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
266
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
267
268
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
269
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
270
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
271
272
273
274
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
275
276
            '* "float32" for FP32 precision.',
        )
277
278
279
280
281
282
283
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
284
285
286
287
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
288
289
290
291
292
293
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
294
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
295
296
                "bitsandbytes",
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
297
298
            help="The quantization method.",
        )
299
300
301
302
303
304
305
306
307
308
309
310
311
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
            choices=["cuda", "xpu"],
            help="The device type.",
        )
312
313
314
315
316
317
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
318
319
320
321
322
323
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
324
325
326
327
328
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
329
330

        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
331
332
333
334
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
335
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
336
        )
337
338
339
340
341
342
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
343
344
345
346
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
347
348
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
349
        )
350
351
352
353
354
355
356
357
358
359
360
361
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
362
        parser.add_argument(
363
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
364
            type=str,
365
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
366
            choices=["lpm", "random", "fcfs", "dfs-weight"],
367
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
368
        )
369
370
371
372
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
373
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
374
        )
375
376

        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
377
        parser.add_argument(
378
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
379
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
380
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
381
            default=ServerArgs.tp_size,
382
            help="The tensor parallelism size.",
383
        )
384
385
386
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
387
            default=ServerArgs.stream_interval,
388
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
389
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
390
391
392
393
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
394
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
395
        )
396
397
398
399
400
401
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
402
403
404
405
406
407
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
408
409
410
411
412
413
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory.",
        )
414
415
416
417
418
419
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
420
421

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
422
423
424
425
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
426
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
427
        )
428
        parser.add_argument(
429
430
431
432
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
433
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
434
        parser.add_argument(
435
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
436
            action="store_true",
437
            help="Log the inputs and outputs of all requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
438
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
439
440
441
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
442
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
443
        )
444
445
446
447
448
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
449
450
451
452
453
454
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch",
        )
455

456
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
457
458
459
460
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
461
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
462
        )
463
464
465
466
467
468
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
469
470
471
472
473
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
474

475
476
        # Data parallelism
        parser.add_argument(
477
            "--data-parallel-size",
478
479
480
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
481
            help="The data parallelism size.",
482
483
484
485
486
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
487
            help="The load balancing strategy for data parallelism.",
488
489
490
491
492
493
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

494
        # Multi-node distributed serving
495
        parser.add_argument(
496
497
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
498
            type=str,
499
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
500
501
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
502
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
503
        )
504
505
506
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
507

Lianmin Zheng's avatar
Lianmin Zheng committed
508
509
510
511
512
513
514
515
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

Shuo Yang's avatar
Shuo Yang committed
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

        # Kernel backend
570
571
572
573
574
575
576
577
578
579
580
581
582
583
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
584
585
586
587
588
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
589
            help="Choose the backend for grammar-guided decoding.",
590
        )
591
592

        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
593
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
594
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
595
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
596
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
597
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
598
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
599
            "--disable-jump-forward",
Liangsheng Yin's avatar
Liangsheng Yin committed
600
            action="store_true",
Lianmin Zheng's avatar
Lianmin Zheng committed
601
            help="Disable jump-forward for grammar-guided decoding.",
Liangsheng Yin's avatar
Liangsheng Yin committed
602
        )
603
604
605
606
607
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
608
        parser.add_argument(
609
610
611
612
613
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
614
615
616
617
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
618
619
620
621
622
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
Ke Bao's avatar
Ke Bao committed
623
624
625
626
627
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
628
629
630
631
        parser.add_argument(
            "--disable-nan-detection",
            action="store_true",
            help="Disable the NaN detection for better performance.",
632
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
633
        parser.add_argument(
634
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
635
            action="store_true",
636
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
637
        )
638
639
640
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
641
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
642
        )
Ke Bao's avatar
Ke Bao committed
643
644
645
646
647
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
648
649
650
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
651
652
            help="Optimize the model with torch.compile. Experimental feature.",
        )
653
        parser.add_argument(
654
            "--torch-compile-max-bs",
655
            type=int,
656
            default=ServerArgs.torch_compile_max_bs,
657
658
            help="Set the maximum batch size when using torch compile.",
        )
659
        parser.add_argument(
660
            "--cuda-graph-max-bs",
661
            type=int,
662
            default=ServerArgs.cuda_graph_max_bs,
663
664
            help="Set the maximum batch size for cuda graph.",
        )
665
666
667
668
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
669
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
670
        )
671
672
673
674
675
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
676
        parser.add_argument(
677
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
678
            action="store_true",
679
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
680
        )
681
        parser.add_argument(
682
            "--triton-attention-reduce-in-fp32",
683
            action="store_true",
684
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
685
            "This only affects Triton attention kernels.",
686
        )
687
688
689
690
691
692
693
694
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
695
696
697
698
699
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
700

701
        # Deprecated arguments
702
703
704
705
706
        parser.add_argument(
            "--enable-overlap-schedule",
            action=DeprecatedAction,
            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
        )
707
708
709
710
711
712
713
714
715
716
717
        parser.add_argument(
            "--disable-flashinfer",
            action=DeprecatedAction,
            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action=DeprecatedAction,
            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
718
719
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        The CLI stores the parallelism sizes under ``tensor_parallel_size`` and
        ``data_parallel_size``. These are first copied onto the dataclass field
        names ``tp_size`` and ``dp_size``. Note that this writes those two
        attributes back onto ``args`` itself. Every dataclass field is then
        read from ``args`` by name.
        """
        # Alias the long CLI option names to the short dataclass field names.
        for field_name, cli_name in (
            ("tp_size", "tensor_parallel_size"),
            ("dp_size", "data_parallel_size"),
        ):
            setattr(args, field_name, getattr(args, cli_name))

        init_kwargs = {}
        for field in dataclasses.fields(cls):
            init_kwargs[field.name] = getattr(args, field.name)
        return cls(**init_kwargs)

    def url(self):
        """Return the HTTP base URL ``http://<host>:<port>`` of this server.

        IPv6 hosts (as judged by ``is_ipv6``) are wrapped in square brackets,
        so that the port separator is unambiguous.
        """
        host_part = f"[{self.host}]" if is_ipv6(self.host) else self.host
        return f"http://{host_part}:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
730

731
732
733
734
735
    def check_server_args(self):
        """Validate cross-argument constraints and normalize ``lora_paths``.

        Raises:
            AssertionError: if ``tp_size`` is not divisible by ``nnodes``; if
                data parallelism is combined with multiple nodes; if
                ``max_loras_per_batch`` is not positive; if LoRA adapters are
                requested without disabling CUDA graph and the radix cache;
                or if ``base_gpu_id`` is negative.

        Side effect: when ``lora_paths`` is a list, it is replaced by a dict.
        Entries of the form ``"name=path"`` are split at the first ``=``; a
        bare ``"path"`` entry maps to itself.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1
        ), "multi-node data parallel is not supported"
        # Checked on its own so that a non-positive value is not misreported
        # as a LoRA compatibility problem.
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
        assert (
            # FIXME: LoRA does not yet work together with CUDA graph or the
            # radix cache, so both must be disabled when adapters are given.
            (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"

        if isinstance(self.lora_paths, list):
            named_paths = {}
            for lora_path in self.lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    named_paths[name] = path
                else:
                    named_paths[lora_path] = lora_path
            self.lora_paths = named_paths

Lianmin Zheng's avatar
Lianmin Zheng committed
756

Lianmin Zheng's avatar
Lianmin Zheng committed
757
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments, excluding the program name
            (typically `sys.argv[1:]`). Because an explicit list is passed to
            `parse_args`, an empty list parses no arguments rather than
            falling back to `sys.argv`.

    Returns:
        The server arguments.

    Raises:
        SystemExit: raised by argparse when `argv` contains invalid or
            missing arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    return ServerArgs.from_cli_args(raw_args)


Lianmin Zheng's avatar
Lianmin Zheng committed
775
776
@dataclasses.dataclass
class PortArgs:
    """Communication endpoints shared by the server's processes.

    The three ``*_ipc_name`` fields are paths of temporary files used as zmq
    IPC names. ``nccl_port`` is the TCP port used for NCCL (torch.dist)
    initialization.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def _new_ipc_name() -> str:
        """Create a new persistent temporary file and return its path.

        The file handle is closed immediately, so no descriptor is left open
        per name. The file itself is kept on disk because ``delete=False``.
        """
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            return tmp_file.name

    @staticmethod
    def init_new(server_args) -> "PortArgs":
        """Allocate fresh IPC names and an available NCCL port.

        The port search starts at ``server_args.port`` plus a random offset in
        [100, 1000]. It then advances in steps of 42 until
        ``is_port_available`` accepts the port.
        """
        port = server_args.port + random.randint(100, 1000)
        # NOTE(review): there is no upper bound here; if many ports are busy
        # the candidate can exceed 65535 — confirm is_port_available copes.
        while not is_port_available(port):
            port += 42

        return PortArgs(
            tokenizer_ipc_name=PortArgs._new_ipc_name(),
            scheduler_input_ipc_name=PortArgs._new_ipc_name(),
            detokenizer_ipc_name=PortArgs._new_ipc_name(),
            nccl_port=port,
        )

802
803
804
805
806
807
808
809
810
811

class LoRAPathAction(argparse.Action):
    """argparse action that stores LoRA adapter specs as a ``{name: path}`` dict.

    Each value is either ``name=path`` (split at the first ``=``) or a bare
    ``path``, which is used as its own name.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        setattr(namespace, self.dest, adapters)
        for spec in values:
            name, separator, path = spec.partition("=")
            if not separator:
                name = path = spec
            adapters[name] = path
812
813
814
815
816
817
818
819
820
821


class DeprecatedAction(argparse.Action):
    """argparse action for removed flags: passing the flag is a hard error.

    By default the option takes no values (``nargs=0``). When it appears on
    the command line, a ``ValueError`` is raised whose message is the option's
    ``help`` text.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        deprecation_message = self.help
        raise ValueError(deprecation_message)