server_args.py 17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import logging
21
import random
22
from typing import List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
23

24
25
# Module-level logger named after this module; ServerArgs.check_server_args
# uses it to announce when it overrides a user-supplied option.
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
26
27
28

@dataclasses.dataclass
class ServerArgs:
    """All configurable arguments of the server.

    Every field is exposed as a command-line flag by ``add_cli_args``, and
    ``from_cli_args`` builds an instance from the parsed namespace.
    ``__post_init__`` fills in derived defaults (tokenizer path, served model
    name, static memory fraction, random seed) and normalizes a few fields.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None  # None -> same as model_path
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    dtype: str = "auto"
    # NOTE: the --trust-remote-code CLI flag is ``store_true`` (False unless
    # passed), so this True default only applies to programmatic construction.
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    served_model_name: Optional[str] = None  # None -> same as model_path
    chat_template: Optional[str] = None

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    # Normalized to a list by __post_init__: an int becomes [int], None -> [].
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None  # None -> derived from tp_size
    max_running_requests: Optional[int] = None
    max_num_reqs: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # A non-positive value disables chunked prefill and is stored as None.
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None  # None -> random value in __post_init__

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    enable_p2p_check: bool = False
    enable_mla: bool = False
    attention_reduce_in_fp32: bool = False
    efficient_weight_load: bool = False
    disable_custom_all_reduce: bool = False

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    def __post_init__(self):
        """Fill in derived defaults and normalize fields after construction."""
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        if self.mem_fraction_static is None:
            # Larger tensor-parallel groups get a smaller static fraction,
            # leaving more memory headroom on each GPU.
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register one command-line flag per ServerArgs field on *parser*.

        Flag defaults are taken from the dataclass defaults wherever the flag
        takes a value; boolean switches use ``store_true``.
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-num-reqs",
            type=int,
            default=ServerArgs.max_num_reqs,
            help="The maximum number of requests to serve in the memory pool. If the model has a large context length, you may need to decrease this value to avoid out-of-memory errors.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            default=ServerArgs.nccl_init_addr,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank",
            type=int,
            default=ServerArgs.node_rank,
            help="The node rank.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enable mixing prefill and decode in a chunked batch.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile, experimental feature.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--enable-mla",
            action="store_true",
            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
            # Adjacent literals are concatenated; keep the trailing space.
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a namespace produced by ``add_cli_args``.

        argparse stores the parallelism flags under their first long option
        (``tensor_parallel_size`` / ``data_parallel_size``). They are copied to
        ``tp_size`` / ``dp_size`` on *args* itself, which mutates the caller's
        namespace; this is kept for backward compatibility.
        """
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the base HTTP URL of the server, e.g. ``http://127.0.0.1:30000``.

        An IPv6 literal host (one containing ``:``) is wrapped in brackets as
        RFC 3986 requires, e.g. ``http://[::1]:30000``.
        """
        host = self.host
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"http://{host}:{self.port}"

    def check_server_args(self):
        """Validate cross-field constraints and apply model-specific overrides.

        Raises:
            AssertionError: if tp_size is not divisible by nnodes, or if data
                parallelism (dp_size > 1) is combined with a node_rank.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
        if "gemma-2" in self.model_path.lower():
            # gemma-2 uses sliding-window attention; flashinfer is forced on
            # even if the user passed --disable-flashinfer.
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.disable_flashinfer = False
465

Lianmin Zheng's avatar
Lianmin Zheng committed
466
467
468
469

@dataclasses.dataclass
class PortArgs:
    """Port numbers for the server's internal components.

    Fields are required and positional in the order declared below, so the
    order must stay stable for positional callers.
    """

    # Port for the tokenizer component.
    tokenizer_port: int
    # Port for the controller component.
    controller_port: int
    # Port for the detokenizer component.
    detokenizer_port: int
    # Ports presumably reserved for NCCL communication (inferred from the name).
    nccl_ports: List[int]