# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The arguments of the server."""

import argparse
import dataclasses
import logging
import random
import tempfile
from typing import List, Optional

import torch

from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
    get_amdgpu_memory_capacity,
    get_hpu_memory_capacity,
    get_nvgpu_memory_capacity,
    is_flashinfer_available,
    is_hip,
    is_ipv6,
    is_port_available,
    nullable_str,
)

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ServerArgs:
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    load_format: str = "auto"
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization_param_path: nullable_str = None
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False
    revision: Optional[str] = None
    skip_tokenizer_init: bool = False

    # Port for the HTTP server
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    prefill_only_one_req: bool = False

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300
    download_dir: Optional[str] = None
    base_gpu_id: int = 0

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40
    dump_requests_folder: str = ""

    # API related
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Expert parallelism
    ep_size: int = 1

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Speculative decoding
    speculative_draft_model_path: Optional[str] = None
    speculative_algorithm: Optional[str] = None
    speculative_num_steps: int = 5
    speculative_num_draft_tokens: int = 64
    speculative_eagle_topk: int = 8

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # Optimization/debug options
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False

    def __post_init__(self):
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        if is_hip():
            gpu_mem = get_amdgpu_memory_capacity()
        elif torch.cuda.is_available():
            gpu_mem = get_nvgpu_memory_capacity()
        elif self.device == "hpu":
            gpu_mem = get_hpu_memory_capacity()
        else:
            # GPU memory is not known yet or no GPU is available.
            gpu_mem = None

        # Set mem fraction static, which depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.81
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88
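        # For example, tp_size=8 on 80 GB GPUs reserves roughly 0.81 * 80 GB per
        # GPU for weights plus the KV cache pool; the fraction shrinks as the TP
        # size grows, presumably to leave headroom for per-rank overheads such
        # as NCCL buffers and cuda graphs (a rule of thumb, not a guarantee).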

        # Set chunked prefill size, which depends on the gpu memory capacity
        if self.chunked_prefill_size is None:
            if gpu_mem is not None and gpu_mem < 25_000:
                self.chunked_prefill_size = 2048
            else:
                self.chunked_prefill_size = 8192
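        # e.g. a 24 GB consumer GPU reports a capacity below the 25_000
        # threshold and therefore defaults to the smaller 2048-token chunks
        # (assuming gpu_mem is reported in MB, as the threshold suggests).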

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on
            # lower-end GPUs with HBM < 25 GB, you can either disable cuda graph
            # or set `cuda_graph_max_bs` to a very small value to reduce the
            # memory overhead of creating cuda graphs, with almost no impact on
            # performance. However, when serving models with TP4 or TP8, cuda
            # graph must stay enabled to maintain high performance. In that
            # case, `cuda_graph_max_bs` can be set to 80 (half of the default
            # 160) to reduce the memory overhead of creating cuda graphs. Logs
            # from TP4 serving of qwen2-72b show that 80 is sufficient and, on
            # lower-end GPUs, avoids the OOM issues seen with the original 160.
            if gpu_mem is not None and gpu_mem < 25_000:
                if self.tp_size < 4:
                    self.cuda_graph_max_bs = 8
                else:
                    self.cuda_graph_max_bs = 80
            else:
                self.cuda_graph_max_bs = 160

        # Choose kernel backends
        if self.device == "hpu":
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        if self.attention_backend is None:
            self.attention_backend = (
                "flashinfer" if is_flashinfer_available() else "triton"
            )
        if self.sampling_backend is None:
            self.sampling_backend = (
                "flashinfer" if is_flashinfer_available() else "pytorch"
            )

        if self.attention_backend == "torch_native":
            logger.warning(
                "Cuda graph is disabled because of using torch native attention backend"
            )
            self.disable_cuda_graph = True

        # Expert parallelism
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.info(
                f"EP MoE is enabled. The expert parallel size is adjusted to match the tensor parallel size ({self.tp_size})."
            )

        # Others
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            self.disable_overlap_schedule = True
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
                "Overlap scheduler is disabled."
            )
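            # With the defaults above, chunked_prefill_size drops from 8192 to
            # 4096 and schedule_conservativeness from 1.0 to 0.3.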

        # Speculative Decoding
        if self.speculative_algorithm == "EAGLE":
            self.prefill_only_one_req = True
            self.disable_cuda_graph_padding = True
            self.disable_radix_cache = True
            self.disable_overlap_schedule = True
            self.chunked_prefill_size = -1
            logger.info(
                "The radix cache, chunked prefill, and overlap scheduler are disabled because EAGLE speculative decoding is used."
            )
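        # A typical launch then looks like (sketch; model paths illustrative):
        #   python -m sglang.launch_server --model-path <target-model> \
        #     --speculative-algorithm EAGLE \
        #     --speculative-draft-model-path <eagle-draft-model>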

        # GGUF
        if (
            self.load_format == "auto" or self.load_format == "gguf"
        ) and check_gguf_file(self.model_path):
            self.quantization = self.load_format = "gguf"
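        # e.g. passing `--model-path ./model.Q4_K_M.gguf` (hypothetical file)
        # flips both load_format and quantization to "gguf" automatically.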

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Model and port args
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip initializing the tokenizer and pass input_ids directly in generate requests.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=[
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
                "gguf",
                "bitsandbytes",
            ],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling. "
            '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
            "quantization.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
            help='Data type for kv cache storage. "auto" will use the model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization-param-path",
            type=nullable_str,
            default=None,
            help="Path to the JSON file containing the KV cache "
            "scaling factors. This should generally be supplied when "
            "the KV cache dtype is FP8. Otherwise, KV cache scaling factors "
            "default to 1.0, which may cause accuracy issues.",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
                "gguf",
                "modelopt",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
            choices=["cuda", "xpu", "hpu"],
            help="The device type.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The built-in chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="The specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )

        # Memory and scheduling
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading.",
        )
        parser.add_argument(
            "--prefill-only-one-req",
            type=bool,
            help="If true, we only prefill one request in each prefill batch.",
            default=ServerArgs.prefill_only_one_req,
        )

        # Other runtime options
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher.",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model to generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=ServerArgs.download_dir,
            help="Model download directory.",
        )
        parser.add_argument(
            "--base-gpu-id",
            type=int,
            default=ServerArgs.base_gpu_id,
            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
        )

        # Logging
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable logging Prometheus metrics.",
        )
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch.",
        )
        parser.add_argument(
            "--dump-requests-folder",
            type=str,
            default=ServerArgs.dump_requests_folder,
            help="Dump raw requests to a folder for replay.",
        )

        # API related
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return the number of cached tokens in usage.prompt_tokens_details for each OpenAI request.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
            "--ep-size",
            type=int,
            default=ServerArgs.ep_size,
            help="The expert parallelism size.",
        )

        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either plain paths or renamed paths in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, including base-only requests.",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton", "torch_native"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for grammar-guided decoding.",
        )

        # Speculative decoding
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
            choices=["EAGLE"],
            help="Speculative algorithm.",
        )
        parser.add_argument(
            "--speculative-draft-model-path",
            type=str,
            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
        )
        parser.add_argument(
            "--speculative-num-steps",
            type=int,
            help="The number of steps sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_steps,
        )
        parser.add_argument(
            "--speculative-num-draft-tokens",
            type=int,
            help="The number of token sampled from draft model in Speculative Decoding.",
            default=ServerArgs.speculative_num_draft_tokens,
        )
        parser.add_argument(
            "--speculative-eagle-topk",
            type=int,
            help="The number of token sampled from draft model in eagle2 each step.",
            choices=[1, 2, 4, 8],
            default=ServerArgs.speculative_eagle_topk,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The type of heavy channels in double sparsity attention",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-jump-forward",
            action="store_true",
            help="Disable jump-forward for grammar-guided decoding.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-outlines-disk-cache",
            action="store_true",
            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",
            help="Disable the overlap scheduler, which overlaps the CPU scheduler with the GPU model worker.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-dp-attention",
            action="store_true",
            help="Enable data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
            help="Enable expert parallelism for MoE. The ep size is equal to the tp size.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
            type=int,
            nargs="+",
            help="Set the list of batch sizes for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
        )
        parser.add_argument(
            "--enable-nan-detection",
            action="store_true",
            help="Enable the NaN detection for debugging purposes.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--triton-attention-num-kv-splits",
            type=int,
            default=ServerArgs.triton_attention_num_kv_splits,
            help="The number of KV splits in flash decoding Triton kernel. A larger value is better in longer context scenarios. The default value is 8.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
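        # argparse derives each dest from the first long option string, so the
        # parsed values land on e.g. `args.tensor_parallel_size`; mirror them
        # onto the short names used by the dataclass fields before construction.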
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        args.ep_size = args.expert_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
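        # e.g. "http://127.0.0.1:30000", or "http://[::1]:30000" for an IPv6 host.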
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1
        ), "multi-node data parallel is not supported"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"
        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"

        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path


def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args
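# Example usage (sketch; the model path is illustrative):
#   server_args = prepare_server_args(["--model-path", "my-org/my-model"])
#   print(server_args.url())  # http://127.0.0.1:30000 with the defaults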


@dataclasses.dataclass
class PortArgs:
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def init_new(server_args) -> "PortArgs":
        port = server_args.port + random.randint(100, 1000)
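        # Probe for a free nccl port starting from a randomized offset so that
        # multiple servers on one machine are unlikely to race for the same
        # port; the asymmetric +42/-43 steps keep the probe bounded near 60000.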
        while True:
            if is_port_available(port):
                break
            if port < 60000:
                port += 42
            else:
                port -= 43

        return PortArgs(
            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
            nccl_port=port,
        )


class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path
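# For example, `--lora-paths coder=/adapters/coder /adapters/writer` yields
# {"coder": "/adapters/coder", "/adapters/writer": "/adapters/writer"}
# (hypothetical paths): named adapters map name -> path, bare paths map to themselves.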


class DeprecatedAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super(DeprecatedAction, self).__init__(
            option_strings, dest, nargs=nargs, **kwargs
        )

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)
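# Intended use (sketch; the flag name is illustrative): register a removed flag
# so that old launch scripts fail with a clear message instead of being ignored:
#   parser.add_argument(
#       "--some-removed-flag",
#       action=DeprecatedAction,
#       help="--some-removed-flag is deprecated. Use --new-flag instead.",
#   )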