"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""The arguments of the server."""

import argparse
import dataclasses
import logging
import random
import tempfile
from typing import List, Optional

from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ServerArgs:
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: int = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Distributed args
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_penalizer: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    max_torch_compile_bs: int = 32
    torchao_config: str = ""
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False

    def __post_init__(self):
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        # Mem fraction depends on the tensor parallelism size
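        # (Heuristic: larger tp_size values use a smaller static fraction,
        # presumably to leave more headroom for communication buffers and
        # activations across GPUs.)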
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Deprecation warnings
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
            self.attention_backend = "triton"
        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead. "
            )
            self.sampling_backend = "pytorch"

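        # FlashInfer is not installed or unavailable on this platform; fall
        # back to the Triton attention backend and PyTorch sampling backend.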
        if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "For an unknown reason, the tokenizer will add an additional token at the end of the prompt when trust_remote_code=True, so trust_remote_code is disabled for this model."
            )
            self.trust_remote_code = False

        if "gemma-2" in self.model_path.lower():
            logger.info("Gemma-2 uses sliding window attention, which requires the flashinfer attention backend.")
            self.attention_backend = "flashinfer"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip initializing the tokenizer and pass input_ids directly in generate requests.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            choices=["cuda", "xpu"],
            help="The device type.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 disables chunked prefill.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of token length. A smaller value makes streaming smoother, while a larger value gives higher throughput.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model to generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of the HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set the API key of the server. It is also used in the OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in the backend.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )
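        # Example (hypothetical override values):
        #   --json-model-override-args '{"rope_scaling": {"factor": 2.0, "type": "linear"}}'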

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either plain paths or renamed paths in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=ServerArgs.max_loras_per_batch,
            help="Maximum number of adapters in a running batch, including base-only requests.",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--disable-penalizer",
            action="store_true",
            help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--max-torch-compile-bs",
            type=int,
            default=ServerArgs.max_torch_compile_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access; otherwise, P2P access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
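        # argparse stores values under the canonical option names
        # (--tensor-parallel-size, --data-parallel-size); copy them to the
        # short dataclass field names before constructing ServerArgs.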
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by the number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1
        ), "multi-node data parallel is not supported"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of LoRA with CUDA graph and radix attention is in progress"

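        # Normalize lora_paths given as a list (e.g., when ServerArgs is
        # constructed programmatically) into a dict of {name: path}.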
        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path


def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
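
    Example (a minimal sketch; the model path below is only illustrative):

        server_args = prepare_server_args(
            ["--model-path", "meta-llama/Llama-2-7b-hf"]
        )
        print(server_args.url())  # http://127.0.0.1:30000 by default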
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


@dataclasses.dataclass
class PortArgs:
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def init_new(server_args) -> "PortArgs":
        # Find the first available port after the server's HTTP port for
        # NCCL/distributed initialization.
        port = server_args.port + 1
        while not is_port_available(port):
            port += 1

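        # Use fresh temporary file names as unique ZMQ IPC endpoint paths.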
        return PortArgs(
            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            nccl_port=port,
        )


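# Argparse action that collects --lora-paths values into a dict. A value of
# the form {name}={path} maps name -> path; a bare path maps to itself.
# For example (hypothetical paths), `--lora-paths a=/ckpt/a /ckpt/b` yields
# {"a": "/ckpt/a", "/ckpt/b": "/ckpt/b"}.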
class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path