server_args.py 22.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import logging
21
import random
22
from typing import List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
23

24
25
from sglang.srt.utils import is_hip

26
27
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
28

29
30
31
32
33
34
35
36
37
38
39
class LoRAPathAction(argparse.Action):
    """argparse action that turns LoRA adapter specs into a ``{name: path}`` dict.

    Each value is either ``path`` (registered under the path itself as its
    name) or ``name=path`` (split on the first ``=`` only, so the path may
    itself contain ``=``). Later duplicates of a name overwrite earlier ones.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        # Bind a fresh mapping on the namespace, then populate it in place.
        setattr(namespace, self.dest, adapters)
        for spec in values:
            name, has_sep, path = spec.partition("=")
            adapters[name] = path if has_sep else spec


Lianmin Zheng's avatar
Lianmin Zheng committed
40
41
@dataclasses.dataclass
class ServerArgs:
    """Arguments used to launch the server.

    Every field has a matching command-line flag registered in
    `add_cli_args`. `from_cli_args` builds an instance from a parsed
    namespace. `__post_init__` fills in derived defaults and applies
    backend/model-specific overrides.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    # NOTE(review): this defaults to True here, but `--trust-remote-code` is a
    # store_true flag, so CLI launches default to False. Confirm this is intended.
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    # A single int is normalized to a one-element list (None -> []) in __post_init__.
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # Non-positive values are converted to None (chunked prefill disabled).
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Optimization/debug options
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None

    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    max_torch_compile_bs: int = 32
    torchao_config: str = ""
    enable_p2p_check: bool = False
    enable_mla: bool = False
    triton_attention_reduce_in_fp32: bool = False
    # Mirrors `--efficient-weight-load`. Without this field, from_cli_args
    # silently dropped the parsed flag.
    efficient_weight_load: bool = False

    # LoRA
    # A list of paths when set programmatically. After CLI parsing through
    # LoRAPathAction it is a {name: path} dict.
    lora_paths: Optional[Union[List[str], dict]] = None
    max_loras_per_batch: int = 8

    def __post_init__(self):
        """Fill in missing defaults and resolve the attention/sampling backends."""
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        # A non-positive size disables chunked prefill (stored as None). None
        # is accepted too, so that re-running __post_init__ (for example via
        # dataclasses.replace) does not fail on `None <= 0`.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Deprecation warnings
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
            self.attention_backend = "triton"

        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead. "
            )
            self.sampling_backend = "pytorch"

        # ROCm: flashinfer available later
        if is_hip():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.enable_mla:
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.attention_backend = "triton"

        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_code=True"
            )
            self.trust_remote_code = False

        # NOTE(review): this runs after the ROCm and MLA overrides above, so it
        # forces flashinfer even on ROCm (where flashinfer is noted as not yet
        # available) and when --enable-mla is set. Confirm this is intended.
        if "gemma-2" in self.model_path.lower():
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.attention_backend = "flashinfer"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register one command-line flag per ServerArgs field on `parser`.

        Defaults are taken from the dataclass defaults where one exists, so the
        two stay in sync.
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        # Parsed into `tensor_parallel_size`; from_cli_args copies it to tp_size.
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )

        # Data parallelism
        # Parsed into `data_parallel_size`; from_cli_args copies it to dp_size.
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            default=ServerArgs.nccl_init_addr,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

        # Optimization/debug options
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--max-torch-compile-bs",
            type=int,
            default=ServerArgs.max_torch_compile_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--enable-mla",
            action="store_true",
            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

        # LoRA options
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=ServerArgs.lora_paths,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=ServerArgs.max_loras_per_batch,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a namespace produced by `add_cli_args`.

        The aliased `tensor_parallel_size` / `data_parallel_size` values are
        copied onto `tp_size` / `dp_size` first. Note that this mutates `args`.
        Every dataclass field must then be present on `args`.
        """
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the base HTTP URL of the server."""
        return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        """Validate cross-field constraints.

        Raises:
            AssertionError: if any constraint is violated.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
        # Checked separately so a bad value is not reported as a LoRA
        # compatibility problem.
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
        # FIXME: LoRA currently requires both cuda graph and radix cache to be disabled.
        assert self.lora_paths is None or (
            self.disable_cuda_graph and self.disable_radix_cache
        ), "compatibility of lora and cuda graph and radix attention is in progress"
600

Lianmin Zheng's avatar
Lianmin Zheng committed
601

Lianmin Zheng's avatar
Lianmin Zheng committed
602
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments, typically `sys.argv[1:]`. Passing
            an explicit list (even an empty one) makes argparse parse exactly
            that list instead of falling back to `sys.argv`.

    Returns:
        The server arguments, with defaults resolved by `ServerArgs.__post_init__`.

    Raises:
        SystemExit: raised by argparse when the arguments are invalid or
            `--help` is requested.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


Lianmin Zheng's avatar
Lianmin Zheng committed
620
621
622
@dataclasses.dataclass
class PortArgs:
    tokenizer_port: int
Mingyi's avatar
Mingyi committed
623
    controller_port: int
Lianmin Zheng's avatar
Lianmin Zheng committed
624
    detokenizer_port: int
Mingyi's avatar
Mingyi committed
625
    nccl_ports: List[int]