server_args.py 43.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import argparse
import dataclasses
18
import logging
19
import random
20
21
import tempfile
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
22

23
24
import torch

25
from sglang.srt.hf_transformers_utils import check_gguf_file
26
from sglang.srt.utils import (
HAI's avatar
HAI committed
27
    get_amdgpu_memory_capacity,
28
    get_hpu_memory_capacity,
HAI's avatar
HAI committed
29
    get_nvgpu_memory_capacity,
30
    is_flashinfer_available,
HAI's avatar
HAI committed
31
    is_hip,
32
    is_port_available,
33
    is_valid_ipv6_address,
bjmsong's avatar
bjmsong committed
34
    nullable_str,
35
)
36

37
38
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
39
40
41

@dataclasses.dataclass
class ServerArgs:
Lianmin Zheng's avatar
Lianmin Zheng committed
42
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None  # None -> model_path (set in __post_init__)
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"  # may be switched to "gguf" for GGUF checkpoints
    trust_remote_code: bool = False
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    # Path to a JSON file with KV cache scaling factors; None means not provided.
    quantization_param_path: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None  # None -> model_path (set in __post_init__)
    chat_template: Optional[str] = None
    is_embedding: bool = False
    revision: Optional[str] = None

    # Address of the HTTP server
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    # None -> chosen from tp_size in __post_init__.
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # None -> chosen from the probed device memory in __post_init__; -1 disables
    # chunked prefill.
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "fcfs"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    prefill_only_one_req: bool = False

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    stream_output: bool = False
    random_seed: Optional[int] = None  # None -> random value drawn in __post_init__
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    dist_timeout: Optional[int] = None  # timeout for torch.distributed
    download_dir: Optional[str] = None
    base_gpu_id: int = 0
    gpu_id_step: int = 1

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    log_requests_level: int = 0
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_path: str = "sglang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    dp_size: int = 1  # forced to tp_size when enable_dp_attention is set
    load_balance_method: str = "round_robin"

    # Expert parallelism
    ep_size: int = 1  # forced to tp_size when enable_ep_moe is set

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"

    # Kernel backend
    # None -> resolved in __post_init__ (flashinfer when available).
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Speculative decoding
    speculative_algorithm: Optional[str] = None  # "NEXTN" is mapped to "EAGLE"
    speculative_draft_model_path: Optional[str] = None
    speculative_num_steps: int = 5
    speculative_eagle_topk: int = 4
    speculative_num_draft_tokens: int = 8
    speculative_accept_threshold_single: float = 1.0
    speculative_accept_threshold_acc: float = 1.0
    speculative_token_map: Optional[str] = None

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    # None -> chosen from the probed device memory and tp_size in __post_init__.
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8  # set to 16 on ROCm in __post_init__
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    enable_custom_logit_processor: bool = False
    tool_call_parser: Optional[str] = None
    enable_hierarchical_cache: bool = False
    enable_flashinfer_mla: bool = False
    flashinfer_mla_disable_ragged: bool = False
    warmups: Optional[str] = None

    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
    debug_tensor_dump_input_file: Optional[str] = None
    debug_tensor_dump_inject: bool = False
180

Lianmin Zheng's avatar
Lianmin Zheng committed
181
    def __post_init__(self):
        """Derive defaults that depend on other arguments or on the host hardware.

        Called automatically by the dataclass-generated ``__init__``. In order:

        1. ``tokenizer_path`` and ``served_model_name`` default to
           ``model_path``; ``random_seed`` defaults to a random integer in
           ``[0, 2**30]``.
        2. The memory capacity of the selected ``device`` is probed and used,
           together with ``tp_size``, to choose ``mem_fraction_static``,
           ``chunked_prefill_size`` and ``cuda_graph_max_bs`` when they are
           unset.
        3. The attention and sampling kernel backends are resolved.
        4. Feature-specific adjustments are applied for EP MoE, DP attention,
           EAGLE/NEXTN speculative decoding, GGUF checkpoints and ROCm.

        Mutates ``self`` in place.
        """
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Probe the memory of the device we will actually run on. "cuda" is the
        # only --device choice for both NVIDIA and ROCm GPUs, so the CUDA/ROCm
        # probes apply only to it; a cpu/xpu/hpu run must not be sized from an
        # unrelated GPU that merely happens to be visible on the host.
        if self.device == "cuda" and is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        elif self.device == "cuda" and torch.cuda.is_available():
            gpu_mem = get_nvgpu_memory_capacity()
        elif self.device == "hpu":
            gpu_mem = get_hpu_memory_capacity()
        else:
            # GPU memory is not known yet or no GPU is available.
            gpu_mem = None
        # "Lower-end GPU" threshold: HBM < 25 GB. gpu_mem is presumably in MB
        # (compared against 25_000) -- TODO confirm against the probe helpers.
        is_low_memory_gpu = gpu_mem is not None and gpu_mem < 25_000

        # Set mem fraction static, which depends on the tensor parallelism size:
        # larger TP sizes reserve a smaller static fraction.
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.81
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            self.chunked_prefill_size = 2048 if is_low_memory_gpu else 8192

        # Set cuda graph max batch size.
        # Capturing CUDA graphs costs memory on lower-end GPUs (HBM < 25 GB).
        # Measurements showed TP1/TP2 models lose almost no performance with cuda
        # graph disabled or a very small max batch size, whereas TP4/TP8 need
        # cuda graph to stay fast. For those, 80 (half the default 160) was
        # sufficient in TP4 qwen2-72b serving and avoids OOM on these GPUs.
        if self.cuda_graph_max_bs is None:
            if not is_low_memory_gpu:
                self.cuda_graph_max_bs = 160
            elif self.tp_size < 4:
                self.cuda_graph_max_bs = 8
            else:
                self.cuda_graph_max_bs = 80

        # Choose kernel backends
        if self.device == "hpu":
            # HPU always runs with the torch-native attention and pytorch
            # sampling backends; make an overridden explicit choice visible.
            if self.attention_backend not in (None, "torch_native") or (
                self.sampling_backend not in (None, "pytorch")
            ):
                logger.warning(
                    "Device hpu only uses attention_backend='torch_native' and "
                    "sampling_backend='pytorch'. Overriding the requested "
                    f"attention_backend={self.attention_backend!r} and "
                    f"sampling_backend={self.sampling_backend!r}."
                )
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        if self.attention_backend is None:
            self.attention_backend = (
                "flashinfer" if is_flashinfer_available() else "triton"
            )
        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True

        # Expert parallelism: EP MoE ties the expert-parallel size to the
        # tensor-parallel size, overriding any user-supplied ep_size.
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.info(
                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

        # Others
        if self.enable_dp_attention:
            # DP attention forces dp_size to tp_size (overriding any user value),
            # so tp_size is divisible by dp_size by construction.
            self.dp_size = self.tp_size
            # Halve the chunk size (a -1 "disabled" sentinel stays -1, since
            # -1 // 2 == -1) and lower the scheduling conservativeness.
            self.chunked_prefill_size //= 2
            self.schedule_conservativeness *= 0.3
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
            )

        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"

        if self.speculative_algorithm == "EAGLE":
            # EAGLE runs without the overlap scheduler, radix cache, chunked
            # prefill or cuda graph padding, and prefills one request per batch.
            self.disable_overlap_schedule = True
            self.prefill_only_one_req = True
            self.disable_cuda_graph_padding = True
            self.disable_radix_cache = True
            self.chunked_prefill_size = -1
            logger.info(
                f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
            )

        # GGUF: a GGUF checkpoint implies both the gguf load format and gguf
        # quantization (only when load_format was left on "auto" or is "gguf").
        if self.load_format in ("auto", "gguf") and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"

        # AMD-specific Triton attention KV splits default number.
        # NOTE(review): this is unconditional, so on ROCm it also replaces an
        # explicitly passed --triton-attention-num-kv-splits value.
        if is_hip():
            self.triton_attention_num_kv_splits = 16

Lianmin Zheng's avatar
Lianmin Zheng committed
297
298
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
299
        # Model and port args
Lianmin Zheng's avatar
Lianmin Zheng committed
300
301
302
303
304
305
306
307
308
309
310
311
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
312
313
314
315
316
317
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
318
319
320
321
322
323
324
325
326
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
327
328
329
330
331
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
332
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
333
334
335
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
336
337
338
339
340
341
342
343
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
                "gguf",
                "bitsandbytes",
344
                "layered",
345
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
346
347
348
349
350
351
352
353
354
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
355
            "which is mainly for profiling."
356
357
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
358
359
360
361
            "quantization."
            '"layered" loads weights layer by layer so that one can quantize a '
            "layer before loading another to make the peak memory envelope "
            "smaller.",
Lianmin Zheng's avatar
Lianmin Zheng committed
362
        )
363
364
365
366
367
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
368
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
369
            "--dtype",
Cody Yu's avatar
Cody Yu committed
370
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
371
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
372
373
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
374
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
375
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
376
377
378
379
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
380
381
            '* "float32" for FP32 precision.',
        )
382
383
384
385
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
bjmsong's avatar
bjmsong committed
386
387
388
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
389
390
391
392
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
393
394
395
396
397
398
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
Ying Sheng's avatar
Ying Sheng committed
399
                "awq_marlin",
Ying Sheng's avatar
Ying Sheng committed
400
                "bitsandbytes",
401
                "gguf",
402
                "modelopt",
403
                "w8a8_int8",
Ying Sheng's avatar
Ying Sheng committed
404
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
405
406
            help="The quantization method.",
        )
407
408
409
410
411
412
413
414
415
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied, when "
            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues. ",
        )
416
417
418
419
420
421
422
423
424
425
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
426
            choices=["cuda", "xpu", "hpu", "cpu"],
427
428
            help="The device type.",
        )
429
430
431
432
433
434
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
435
436
437
438
439
440
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
441
442
443
444
445
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
446
447
448
449
450
451
452
453
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
454
        # Memory and scheduling
Lianmin Zheng's avatar
Lianmin Zheng committed
455
456
457
458
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
459
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
460
        )
461
462
463
464
465
466
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
467
468
469
470
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
471
472
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
473
        )
474
475
476
477
478
479
480
481
482
483
484
485
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
486
        parser.add_argument(
487
            "--schedule-policy",
Lianmin Zheng's avatar
Lianmin Zheng committed
488
            type=str,
489
            default=ServerArgs.schedule_policy,
Liangsheng Yin's avatar
Liangsheng Yin committed
490
            choices=["lpm", "random", "fcfs", "dfs-weight"],
491
            help="The scheduling policy of the requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
492
        )
493
494
495
496
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
497
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
498
        )
499
500
501
502
503
504
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading",
        )
505
506
507
508
509
510
        parser.add_argument(
            "--prefill-only-one-req",
            type=bool,
            help="If true, we only prefill one request at one prefill batch",
            default=ServerArgs.prefill_only_one_req,
        )
511

512
        # Other runtime options
Lianmin Zheng's avatar
Lianmin Zheng committed
513
        parser.add_argument(
514
            "--tensor-parallel-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
515
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
516
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
517
            default=ServerArgs.tp_size,
518
            help="The tensor parallelism size.",
519
        )
520
521
522
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
523
            default=ServerArgs.stream_interval,
524
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
525
        )
526
527
528
529
530
        parser.add_argument(
            "--stream-output",
            action="store_true",
            help="Whether to output as a sequence of disjoint segments.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
531
532
533
534
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
535
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
536
        )
537
538
539
540
541
542
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
543
544
545
546
547
548
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
549
550
551
552
553
554
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=ServerArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )
555
556
557
558
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
Lianmin Zheng's avatar
Lianmin Zheng committed
559
            help="Model download directory for huggingface.",
560
        )
561
562
563
564
565
566
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )
567
568
569
570
571
572
        parser.add_argument(
            "--gpu-id-step",
            type=int,
            default=ServerArgs.gpu_id_step,
            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
        )
573
574

        # Logging
Lianmin Zheng's avatar
Lianmin Zheng committed
575
576
577
578
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
579
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
580
        )
581
        parser.add_argument(
582
583
584
585
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
586
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
587
        parser.add_argument(
588
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
589
            action="store_true",
590
591
592
593
594
595
596
597
            help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
        )
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
            choices=[0, 1, 2],
Lianmin Zheng's avatar
Lianmin Zheng committed
598
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
599
600
601
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
602
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
603
        )
604
605
606
607
608
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
609
610
611
612
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
613
            help="The log interval of decode batch.",
614
        )
615

616
        # API related
Liangsheng Yin's avatar
Liangsheng Yin committed
617
618
619
620
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
621
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
622
        )
623
        parser.add_argument(
624
            "--file-storage-path",
625
            type=str,
626
            default=ServerArgs.file_storage_path,
627
628
            help="The path of the file storage in backend.",
        )
629
630
631
632
633
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
634

635
636
        # Data parallelism
        parser.add_argument(
637
            "--data-parallel-size",
638
639
640
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
641
            help="The data parallelism size.",
642
643
644
645
646
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
647
            help="The load balancing strategy for data parallelism.",
648
649
650
651
652
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )
653

xiaobochen's avatar
xiaobochen committed
654
655
656
657
658
659
660
661
        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )
662

663
        # Multi-node distributed serving
664
        parser.add_argument(
665
666
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
667
            type=str,
668
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
669
670
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
671
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
672
        )
673
674
675
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )
676

Lianmin Zheng's avatar
Lianmin Zheng committed
677
678
679
680
681
682
683
684
        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

685
686
687
688
689
690
691
        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
692
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
693
694
695
696
697
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
698
699
700
701
702
703
704
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
705
706
707
        )

        # Kernel backend
708
709
710
        parser.add_argument(
            "--attention-backend",
            type=str,
711
            choices=["flashinfer", "triton", "torch_native"],
712
713
714
715
716
717
718
719
720
721
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
722
723
724
        parser.add_argument(
            "--grammar-backend",
            type=str,
725
            choices=["xgrammar", "outlines", "llguidance"],
726
            default=ServerArgs.grammar_backend,
Lianmin Zheng's avatar
Lianmin Zheng committed
727
            help="Choose the backend for grammar-guided decoding.",
728
        )
729
730
731
732
733
        parser.add_argument(
            "--enable-flashinfer-mla",
            action="store_true",
            help="Enable FlashInfer MLA optimization",
        )
734
735
736
737
738
        parser.add_argument(
            "--flashinfer-mla-disable-ragged",
            action="store_true",
            help="Not using ragged prefill wrapper when running flashinfer mla",
        )
739

740
741
742
743
        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
744
            choices=["EAGLE", "NEXTN"],
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
761
            help="The number of tokens sampled from the draft model in eagle2 each step.",
762
763
764
            choices=[1, 2, 4, 8],
            default=ServerArgs.speculative_eagle_topk,
        )
765
766
767
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
768
            help="The number of tokens sampled from the draft model in Speculative Decoding.",
769
770
            default=ServerArgs.speculative_num_draft_tokens,
        )
771
772
773
774
775
776
777
778
779
780
781
782
        parser.add_argument(
            "--speculative-accept-threshold-single",
            type=float,
            help="Accept a draft token if its probability in the target model is greater than this threshold.",
            default=ServerArgs.speculative_accept_threshold_single,
        )
        parser.add_argument(
            "--speculative-accept-threshold-acc",
            type=float,
            help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
            default=ServerArgs.speculative_accept_threshold_acc,
        )
783
784
785
786
787
788
        parser.add_argument(
            "--speculative-token-map",
            type=str,
            help="The path of the draft model's small vocab table.",
            default=ServerArgs.speculative_token_map,
        )
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

827
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
828
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
829
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
830
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
831
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
832
        )
833
834
835
836
837
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
838
        parser.add_argument(
839
840
841
842
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
843
844
845
846
847
        parser.add_argument(
            "--enable-nccl-nvls",
            action="store_true",
            help="Enable NCCL NVLS for prefill heavy requests when available.",
        )
848
        parser.add_argument(
849
            "--disable-outlines-disk-cache",
850
            action="store_true",
851
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
852
        )
853
854
855
856
857
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
Ke Bao's avatar
Ke Bao committed
858
859
860
        parser.add_argument(
            "--disable-mla",
            action="store_true",
Xiaoyu Zhang's avatar
Xiaoyu Zhang committed
861
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
Ke Bao's avatar
Ke Bao committed
862
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
863
        parser.add_argument(
864
            "--disable-overlap-schedule",
Lianmin Zheng's avatar
Lianmin Zheng committed
865
            action="store_true",
866
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
Lianmin Zheng's avatar
Lianmin Zheng committed
867
        )
868
869
870
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
871
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
872
        )
Ke Bao's avatar
Ke Bao committed
873
874
875
876
877
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
xiaobochen's avatar
xiaobochen committed
878
879
880
881
882
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
883
884
885
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
886
887
            help="Optimize the model with torch.compile. Experimental feature.",
        )
888
        parser.add_argument(
889
            "--torch-compile-max-bs",
890
            type=int,
891
            default=ServerArgs.torch_compile_max_bs,
892
893
            help="Set the maximum batch size when using torch compile.",
        )
894
        parser.add_argument(
895
            "--cuda-graph-max-bs",
896
            type=int,
897
            default=ServerArgs.cuda_graph_max_bs,
898
899
            help="Set the maximum batch size for cuda graph.",
        )
900
901
902
903
904
905
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
906
907
908
909
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
910
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
911
        )
912
913
914
915
916
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
917
        parser.add_argument(
918
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
919
            action="store_true",
920
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
921
        )
922
        parser.add_argument(
923
            "--triton-attention-reduce-in-fp32",
924
            action="store_true",
925
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
926
            "This only affects Triton attention kernels.",
927
        )
928
929
930
931
932
933
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
        )
934
935
936
937
938
939
940
941
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
942
943
944
945
946
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
947
948
949
950
951
        parser.add_argument(
            "--enable-memory-saver",
            action="store_true",
            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
        )
952
953
954
955
956
        parser.add_argument(
            "--allow-auto-truncate",
            action="store_true",
            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
        )
957
958
959
960
961
        parser.add_argument(
            "--enable-custom-logit-processor",
            action="store_true",
            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
        )
YAMY's avatar
YAMY committed
962
963
964
965
966
967
968
        parser.add_argument(
            "--tool-call-parser",
            type=str,
            choices=["qwen25", "mistral", "llama3"],
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
        )
969
970
971
972
973
        parser.add_argument(
            "--enable-hierarchical-cache",
            action="store_true",
            help="Enable hierarchical cache",
        )
974

975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
        # Server warmups
        parser.add_argument(
            "--warmups",
            type=str,
            required=False,
            help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
            "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
        )

        # Debug tensor dumps
        parser.add_argument(
            "--debug-tensor-dump-output-folder",
            type=str,
            default=ServerArgs.debug_tensor_dump_output_folder,
            help="The output folder for dumping tensors.",
        )
        parser.add_argument(
            "--debug-tensor-dump-input-file",
            type=str,
            default=ServerArgs.debug_tensor_dump_input_file,
            help="The input filename for dumping tensors",
        )
        parser.add_argument(
            "--debug-tensor-dump-inject",
            type=str,
            default=ServerArgs.debug_tensor_dump_inject,
            help="Inject the outputs from jax as the input of every layer.",
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
1004
1005
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        The ``tensor_parallel_size`` / ``data_parallel_size`` /
        ``expert_parallel_size`` attributes are first copied onto the short
        names ``tp_size`` / ``dp_size`` / ``ep_size`` that the dataclass fields
        use. Note that this mutates ``args`` in place. Every dataclass field of
        ``cls`` is then read from ``args`` by name, so ``args`` must carry an
        attribute for each field.
        """
        # Alias the long-form parallelism attributes onto the field names.
        for short_name, long_name in (
            ("tp_size", "tensor_parallel_size"),
            ("dp_size", "data_parallel_size"),
            ("ep_size", "expert_parallel_size"),
        ):
            setattr(args, short_name, getattr(args, long_name))

        field_values = {}
        for field in dataclasses.fields(cls):
            field_values[field.name] = getattr(args, field.name)
        return cls(**field_values)

    def url(self):
        """Return the HTTP base URL (scheme, host and port) of this server.

        When ``self.host`` is an IPv6 literal it is wrapped in square brackets
        so the ``:port`` suffix stays unambiguous.
        """
        if not is_valid_ipv6_address(self.host):
            return f"http://{self.host}:{self.port}"
        return f"http://[{self.host}]:{self.port}"
Lianmin Zheng's avatar
Lianmin Zheng committed
1017

1018
1019
1020
1021
1022
    def check_server_args(self):
        """Validate cross-argument constraints and normalize ``lora_paths``.

        Each check raises explicitly instead of using ``assert``. ``assert``
        statements are stripped when Python runs with ``-O``, which would let
        an invalid configuration through silently. ``AssertionError`` is kept
        as the exception type so existing callers that catch it still work.

        Raises:
            AssertionError: if ``tp_size`` is not divisible by ``nnodes``; if
                multi-node data parallelism is requested without DP attention;
                if ``max_loras_per_batch`` is not positive, or LoRA paths are
                set while CUDA graph or radix cache is still enabled; if
                ``base_gpu_id`` is negative or ``gpu_id_step`` is below 1.

        Side effect:
            A list-valued ``lora_paths`` is replaced in place by a
            ``{name: path}`` dict. ``name=path`` entries are split on the
            first ``=``, and bare entries map to themselves.
        """
        if self.tp_size % self.nnodes != 0:
            raise AssertionError("tp_size must be divisible by number of nodes")
        if self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention:
            raise AssertionError(
                "multi-node data parallel is not supported unless dp attention!"
            )
        # FIXME: LoRA is not yet compatible with CUDA graph or radix cache, so
        # both must be disabled whenever lora_paths is set.
        lora_compatible = (
            self.max_loras_per_batch > 0
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        )
        if not lora_compatible:
            raise AssertionError(
                "compatibility of lora and cuda graph and radix attention is in progress"
            )
        if self.base_gpu_id < 0:
            raise AssertionError("base_gpu_id must be non-negative")
        if self.gpu_id_step < 1:
            raise AssertionError("gpu_id_step must be positive")

        # Programmatically constructed ServerArgs may pass a plain list here;
        # normalize it to the same {name: path} mapping used elsewhere.
        if isinstance(self.lora_paths, list):
            normalized = {}
            for entry in self.lora_paths:
                name, sep, path = entry.partition("=")
                if sep:
                    normalized[name] = path
                else:
                    normalized[entry] = entry
            self.lora_paths = normalized

Lianmin Zheng's avatar
Lianmin Zheng committed
1044

Lianmin Zheng's avatar
Lianmin Zheng committed
1045
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Parse command-line tokens into a ``ServerArgs`` instance.

    Registers every server option on a fresh ``argparse.ArgumentParser`` via
    ``ServerArgs.add_cli_args``, parses ``argv``, and converts the resulting
    namespace with ``ServerArgs.from_cli_args``.

    Args:
        argv: The command-line tokens to parse, excluding the program name.
            Typically this is ``sys.argv[1:]``. Note that ``argparse`` falls
            back to reading ``sys.argv`` itself if ``None`` is passed, so pass
            an explicit list (possibly empty) to control what is parsed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args


1063
1064
1065
# Offset from `server_args.port` used by `PortArgs.init_new` as the default
# dist-init port when DP attention runs on a single node without an explicit
# --dist-init-addr. The ZMQ TCP endpoints are then numbered from that port + 1.
ZMQ_TCP_PORT_DELTA = 233


Lianmin Zheng's avatar
Lianmin Zheng committed
1066
1067
@dataclasses.dataclass
class PortArgs:
    """Endpoints used to connect the server's processes.

    The three ``*_ipc_name`` fields are ZMQ endpoint strings: ``ipc://<file>``
    in the default single-node case, or ``tcp://<host>:<port>`` when DP
    attention is enabled.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate a fresh set of endpoints for ``server_args``.

        Args:
            server_args: The server configuration. Reads ``port``,
                ``enable_dp_attention``, ``nnodes`` and ``dist_init_addr``.
            dp_rank: Only used with DP attention. ``None`` selects the
                TokenizerManager -> DataParallelController endpoint; an integer
                selects the scheduler input endpoint of that DP rank.

        Returns:
            A new ``PortArgs``.

        Raises:
            AssertionError: with DP attention, if ``dist_init_addr`` is needed
                but missing, or is not of the form ``host:port``.
        """
        # Probe for a free NCCL port starting from a random offset above the
        # server port: step up by 42 below 60000, otherwise step down by 43.
        port = server_args.port + random.randint(100, 1000)
        while not is_port_available(port):
            if port < 60000:
                port += 42
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=PortArgs._new_ipc_name(),
                scheduler_input_ipc_name=PortArgs._new_ipc_name(),
                detokenizer_ipc_name=PortArgs._new_ipc_name(),
                nccl_port=port,
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        dist_init_host, dist_init_port = PortArgs._resolve_dist_init_addr(server_args)
        port_base = int(dist_init_port) + 1
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 2
        else:
            scheduler_input_port = port_base + 2 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            nccl_port=port,
        )

    @staticmethod
    def _new_ipc_name() -> str:
        """Create an empty temp file and return it as a ZMQ ``ipc://`` endpoint.

        The file stays on disk (``delete=False``), but its handle is closed
        right away instead of being left open until garbage collection.
        """
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            return f"ipc://{tmp.name}"

    @staticmethod
    def _resolve_dist_init_addr(server_args):
        """Return the ``(host, port)`` pair the DP-attention endpoints derive from.

        On a single node without ``dist_init_addr``, this is 127.0.0.1 and
        ``server_args.port + ZMQ_TCP_PORT_DELTA``. Otherwise
        ``dist_init_addr`` must be ``host:port``; the port is returned as a
        string.

        Raises:
            AssertionError: if ``dist_init_addr`` is missing or malformed.
        """
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            return "127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA
        error_msg = "please provide --dist-init-addr as host:port of head node"
        # Fail with the usage message rather than an AttributeError on None.
        if server_args.dist_init_addr is None:
            raise AssertionError(error_msg)
        parts = server_args.dist_init_addr.split(":")
        if len(parts) != 2:
            raise AssertionError(error_msg)
        return parts[0], parts[1]
1122

1123
1124
1125
1126
1127
1128
1129
1130
1131
1132

class LoRAPathAction(argparse.Action):
    """Argparse action that collects ``--lora-paths`` values into a dict.

    Each value is either ``name=path`` (split on the first ``=``) or a bare
    ``path``, which is then used as its own name. The resulting
    ``{name: path}`` dict replaces whatever was stored at ``dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        for raw in values:
            name, has_eq, path = raw.partition("=")
            if has_eq:
                adapters[name] = path
            else:
                adapters[raw] = raw
        setattr(namespace, self.dest, adapters)
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142


class DeprecatedAction(argparse.Action):
    """Argparse action for removed options: passing the flag raises an error.

    By default the option consumes no values (``nargs=0``). Its ``help`` text
    doubles as the error message.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Surface the help string to the user as the failure reason.
        raise ValueError(self.help)