server_args.py 28 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import logging
21
import random
22
23
import tempfile
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
26

27
28
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
29
30
31

@dataclasses.dataclass
class ServerArgs:
    """All configurable options of the SGLang server.

    Instances are usually built from command-line flags via ``add_cli_args``
    followed by ``from_cli_args``. ``__post_init__`` fills in defaults that
    depend on other options (tokenizer path, memory fraction, kernel backends).
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    # NOTE(review): the CLI flag --trust-remote-code is store_true and thus
    # defaults to False, while programmatic construction defaults to True.
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # A value <= 0 disables chunked prefill; __post_init__ then stores None.
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"
    enable_cache_report: bool = False
    watchdog_timeout: float = 600

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Distributed args
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    # check_server_args() converts a list of "name=path" / "path" entries
    # into a {name: path} dict in place.
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_penalizer: bool = False
    disable_nan_detection: bool = False
    enable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: int = 160
    torchao_config: str = ""
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    num_continuous_decode_steps: int = 1

    def __post_init__(self):
        """Fill in derived defaults and apply compatibility overrides."""
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        # Treat an explicit None the same as "disabled" instead of crashing on
        # the comparison below.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Deprecation warnings
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
            self.attention_backend = "triton"

        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead. "
            )
            self.sampling_backend = "pytorch"

        # Without flashinfer, fall back to the triton / pytorch kernels.
        if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        if self.enable_overlap_schedule:
            logger.warning(
                "Overlap scheduler mode is enabled. This is an experimental feature. "
                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                "and embedding APIs are not supported and will lead to wrong results. "
                "The NaN detection is also disabled."
            )
            self.disable_penalizer = True
            self.disable_nan_detection = True

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
            )
            self.trust_remote_code = False

        if "gemma-2" in self.model_path.lower():
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            # NOTE(review): this runs after the availability fallback above,
            # so it selects flashinfer even when flashinfer is not installed.
            self.attention_backend = "flashinfer"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every server option on ``parser``.

        Defaults are read from the dataclass so the CLI and programmatic
        construction stay in sync (``store_true`` flags always default to False).
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            choices=["cuda", "xpu"],
            help="The device type.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The built-in chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            default=ServerArgs.dist_init_addr,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The decode threshold for double sparsity attention",
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=ServerArgs.lora_paths,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=ServerArgs.max_loras_per_batch,
            help="Maximum number of adapters for a running batch, including base-only requests",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for constrained decoding.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--disable-penalizer",
            action="store_true",
            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
        )
        parser.add_argument(
            "--disable-nan-detection",
            action="store_true",
            help="Disable the NaN detection for better performance.",
        )
        parser.add_argument(
            "--enable-overlap-schedule",
            action="store_true",
            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ``ServerArgs`` from a namespace produced by ``add_cli_args``.

        Note: ``args`` is modified in place (``tp_size`` / ``dp_size`` are added).
        """
        # argparse names each dest after the first long option, so the
        # parallelism sizes arrive as tensor_parallel_size / data_parallel_size
        # and must be copied onto the dataclass field names.
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the server's base HTTP URL, bracketing IPv6 hosts."""
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        """Validate cross-option constraints and normalize ``lora_paths``.

        Each ``lora_paths`` entry of the form ``name=path`` becomes
        ``{name: path}``; a bare ``path`` maps to itself.

        Raises:
            AssertionError: If the parallelism layout or LoRA settings are
                unsupported.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1
        ), "multi-node data parallel is not supported"
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
        # FIXME: LoRA is not yet compatible with cuda graph or radix attention.
        assert self.lora_paths is None or (
            self.disable_cuda_graph and self.disable_radix_cache
        ), "compatibility of lora and cuda graph and radix attention is in progress"

        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path

Lianmin Zheng's avatar
Lianmin Zheng committed
704

Lianmin Zheng's avatar
Lianmin Zheng committed
705
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments, excluding the program name. Typically
            `sys.argv[1:]`. An explicit list (even an empty one) is parsed
            as-is, so `parse_args` does not fall back to reading `sys.argv`.

    Returns:
        The server arguments built by `ServerArgs.from_cli_args`.

    Raises:
        SystemExit: Raised by argparse when `argv` contains invalid options
            or `--help`.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    return ServerArgs.from_cli_args(raw_args)


Lianmin Zheng's avatar
Lianmin Zheng committed
723
724
@dataclasses.dataclass
class PortArgs:
    """IPC file names and the NCCL port used by the server's processes."""

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def _new_ipc_name() -> str:
        """Create a uniquely named temporary file and return its path."""
        # delete=False keeps the file on disk after closing. The `with` closes
        # the handle immediately instead of leaving it to garbage collection.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            return f.name

    @staticmethod
    def init_new(server_args) -> "PortArgs":
        """Allocate fresh IPC file names and a free NCCL port.

        Candidate ports start at ``server_args.port + 42`` and advance in steps
        of 42. The first one accepted by ``is_port_available`` is used.

        Args:
            server_args: Object with an integer ``port`` attribute (normally a
                ``ServerArgs``).

        Returns:
            A new ``PortArgs``.

        Raises:
            RuntimeError: If no candidate port up to 65535 is available.
        """
        # Bounded search: stepping past the highest TCP port can never succeed.
        for port in range(server_args.port + 42, 65536, 42):
            if is_port_available(port):
                break
        else:
            raise RuntimeError(
                f"No available NCCL port found above {server_args.port}"
            )

        return PortArgs(
            tokenizer_ipc_name=PortArgs._new_ipc_name(),
            scheduler_input_ipc_name=PortArgs._new_ipc_name(),
            detokenizer_ipc_name=PortArgs._new_ipc_name(),
            nccl_port=port,
        )

750
751
752
753
754
755
756
757
758
759

class LoRAPathAction(argparse.Action):
    """argparse action collecting LoRA adapter specs into a dict on the namespace.

    Each value ``NAME=PATH`` (split on the first ``=``) becomes ``{NAME: PATH}``.
    A value without ``=`` maps to itself.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        setattr(namespace, self.dest, adapters)
        for item in values:
            key, sep, target = item.partition("=")
            adapters[key] = target if sep else item