server_args.py 15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import random
21
from typing import List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
22
23
24
25


@dataclasses.dataclass
class ServerArgs:
    """The arguments of the server.

    Derived defaults (tokenizer path, static memory fraction, port list,
    random seed) are filled in by ``__post_init__``.
    """

    # Model and tokenizer
    model_path: str  # local folder or Hugging Face repo ID of the model weights
    tokenizer_path: Optional[str] = None  # defaults to model_path when unset
    tokenizer_mode: str = "auto"  # "auto" (fast tokenizer if available) or "slow"
    load_format: str = "auto"  # weight format: auto / pt / safetensors / npcache / dummy
    dtype: str = "auto"  # dtype for weights and activations
    trust_remote_code: bool = True  # allow custom modeling code from the Hub
    context_length: Optional[int] = None  # None -> use the model's config.json value
    quantization: Optional[str] = None  # e.g. awq / fp8 / gptq / marlin / ...
    chat_template: Optional[str] = None  # builtin template name or template file path

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    additional_ports: Optional[Union[List[int], int]] = None  # normalized to a list in __post_init__

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None  # None -> derived from tp_size in __post_init__
    max_prefill_tokens: Optional[int] = None  # cap on tokens in one prefill batch
    max_running_requests: Optional[int] = None  # cap on concurrently running requests
    max_num_reqs: Optional[int] = None  # cap on requests held in the memory pool
    schedule_heuristic: str = "lpm"  # lpm / random / fcfs / dfs-weight
    schedule_conservativeness: float = 1.0  # larger -> more conservative scheduling

    # Other runtime options
    tp_size: int = 1  # tensor parallelism size
    stream_interval: int = 1  # streaming buffer size in tokens
    random_seed: Optional[int] = None  # None -> randomized in __post_init__

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None  # None -> reuse log_level
    log_requests: bool = False  # log inputs/outputs of all requests
    show_time_cost: bool = False  # show time cost of custom marks

    # Other
    api_key: str = ""  # empty string disables API-key checking
    file_storage_pth: str = "SGlang_storage"  # backend file storage path

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"  # round_robin / shortest_queue

    # Chunked Prefill
    chunked_prefill_size: Optional[int] = None  # None disables chunked prefill

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_disk_cache: bool = False
    enable_torch_compile: bool = False
    enable_p2p_check: bool = False
    attention_reduce_in_fp32: bool = False  # Triton attention kernels only
    efficient_weight_load: bool = False  # quantize per layer during loading

    # Distributed args
    nccl_init_addr: Optional[str] = None  # nccl init address for multi-node serving
    nnodes: int = 1
    node_rank: Optional[int] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
89
90
91
    def __post_init__(self):
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
Lianmin Zheng's avatar
Lianmin Zheng committed
92
        if self.mem_fraction_static is None:
93
            if self.tp_size >= 16:
Ying Sheng's avatar
Ying Sheng committed
94
                self.mem_fraction_static = 0.80
95
            elif self.tp_size >= 8:
Ying Sheng's avatar
Ying Sheng committed
96
                self.mem_fraction_static = 0.84
Lianmin Zheng's avatar
Lianmin Zheng committed
97
            elif self.tp_size >= 4:
Ying Sheng's avatar
Ying Sheng committed
98
                self.mem_fraction_static = 0.86
Lianmin Zheng's avatar
Lianmin Zheng committed
99
            elif self.tp_size >= 2:
100
                self.mem_fraction_static = 0.88
Ying Sheng's avatar
Ying Sheng committed
101
102
            else:
                self.mem_fraction_static = 0.89
103
104
105
106
        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []
Lianmin Zheng's avatar
Lianmin Zheng committed
107

108
109
110
        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

Lianmin Zheng's avatar
Lianmin Zheng committed
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
Yuanhan Zhang's avatar
Yuanhan Zhang committed
125
126
127
128
129
130
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
131
132
133
134
135
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
136
            help="The additional ports specified for the server.",
137
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
138
139
140
141
142
143
144
145
146
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
164
            "--dtype",
Cody Yu's avatar
Cody Yu committed
165
            type=str,
Lianmin Zheng's avatar
Lianmin Zheng committed
166
            default=ServerArgs.dtype,
Ying Sheng's avatar
Ying Sheng committed
167
168
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
169
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
Ying Sheng's avatar
Ying Sheng committed
170
            "BF16 precision for BF16 models.\n"
Lianmin Zheng's avatar
Lianmin Zheng committed
171
172
173
174
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
Ying Sheng's avatar
Ying Sheng committed
175
176
            '* "float32" for FP32 precision.',
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
177
178
179
180
181
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
182
183
184
185
186
187
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
188
189
190
191
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
Ying Sheng's avatar
Ying Sheng committed
192
193
194
195
196
197
198
199
200
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
Lianmin Zheng's avatar
Lianmin Zheng committed
201
202
            help="The quantization method.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
203
204
205
206
207
208
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
209
210
211
212
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
213
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
Lianmin Zheng's avatar
Lianmin Zheng committed
214
        )
215
        parser.add_argument(
216
            "--max-prefill-tokens",
217
            type=int,
218
            default=ServerArgs.max_prefill_tokens,
219
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
220
        )
221
222
223
224
225
226
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
227
228
229
        parser.add_argument(
            "--max-num-reqs",
            type=int,
Liangsheng Yin's avatar
Liangsheng Yin committed
230
            default=ServerArgs.max_num_reqs,
Liangsheng Yin's avatar
Liangsheng Yin committed
231
232
            help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
233
234
235
236
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
            default=ServerArgs.schedule_heuristic,
Liangsheng Yin's avatar
Liangsheng Yin committed
237
            choices=["lpm", "random", "fcfs", "dfs-weight"],
238
            help="The scheduling heuristic.",
Lianmin Zheng's avatar
Lianmin Zheng committed
239
        )
240
241
242
243
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
244
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
245
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
246
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
247
            "--tp-size",
Lianmin Zheng's avatar
Lianmin Zheng committed
248
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
249
            default=ServerArgs.tp_size,
250
            help="The tensor parallelism size.",
251
        )
252
253
254
        parser.add_argument(
            "--stream-interval",
            type=int,
Lianmin Zheng's avatar
Lianmin Zheng committed
255
            default=ServerArgs.stream_interval,
256
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
257
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
258
259
260
261
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
262
            help="The random seed.",
Lianmin Zheng's avatar
Lianmin Zheng committed
263
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
264
265
266
267
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
268
            help="The logging level of all loggers.",
Lianmin Zheng's avatar
Lianmin Zheng committed
269
        )
270
        parser.add_argument(
271
272
273
274
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
275
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
276
        parser.add_argument(
277
            "--log-requests",
Lianmin Zheng's avatar
Lianmin Zheng committed
278
            action="store_true",
279
            help="Log the inputs and outputs of all requests.",
Lianmin Zheng's avatar
Lianmin Zheng committed
280
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
281
282
283
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
284
            help="Show time cost of custom marks.",
Lianmin Zheng's avatar
Lianmin Zheng committed
285
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
286
287
288
289
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
Ying Sheng's avatar
Ying Sheng committed
290
            help="Set API key of the server.",
Liangsheng Yin's avatar
Liangsheng Yin committed
291
        )
292
293
294
295
296
297
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
298

299
300
301
302
303
        # Data parallelism
        parser.add_argument(
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
304
            help="The data parallelism size.",
305
306
307
308
309
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
310
            help="The load balancing strategy for data parallelism.",
311
312
313
314
315
316
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

317
318
319
320
        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
Ying Sheng's avatar
Ying Sheng committed
321
            help="The nccl init address of multi-node server.",
322
323
        )
        parser.add_argument(
Liangsheng Yin's avatar
Liangsheng Yin committed
324
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
325
        )
Ying Sheng's avatar
Ying Sheng committed
326
        parser.add_argument("--node-rank", type=int, help="The node rank.")
327

Liangsheng Yin's avatar
Liangsheng Yin committed
328
329
330
331
332
333
334
335
        # Chunked prefill
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The size of the chunked prefill.",
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
336
        # Optimization/debug options
Liangsheng Yin's avatar
Liangsheng Yin committed
337
        parser.add_argument(
338
            "--disable-flashinfer",
Liangsheng Yin's avatar
Liangsheng Yin committed
339
            action="store_true",
340
341
342
343
344
345
            help="Disable flashinfer attention kernels.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels.",
Liangsheng Yin's avatar
Liangsheng Yin committed
346
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
347
        parser.add_argument(
Lianmin Zheng's avatar
Lianmin Zheng committed
348
            "--disable-radix-cache",
Liangsheng Yin's avatar
Liangsheng Yin committed
349
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
350
            help="Disable RadixAttention for prefix caching.",
Liangsheng Yin's avatar
Liangsheng Yin committed
351
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
352
        parser.add_argument(
353
            "--disable-regex-jump-forward",
Liangsheng Yin's avatar
Liangsheng Yin committed
354
            action="store_true",
Ying Sheng's avatar
Ying Sheng committed
355
            help="Disable regex jump-forward.",
Liangsheng Yin's avatar
Liangsheng Yin committed
356
        )
357
358
359
360
361
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
362
363
364
365
366
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
367
368
369
370
371
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile, experimental feature.",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
372
        parser.add_argument(
373
            "--enable-p2p-check",
Lianmin Zheng's avatar
Lianmin Zheng committed
374
            action="store_true",
375
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
Lianmin Zheng's avatar
Lianmin Zheng committed
376
        )
377
        parser.add_argument(
378
            "--attention-reduce-in-fp32",
379
            action="store_true",
380
381
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
            "This only affects Triton attention kernels",
382
        )
383
384
385
386
387
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
388
389
390
391
392
393
394
395
396

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        return f"http://{self.host}:{self.port}"

Lianmin Zheng's avatar
Lianmin Zheng committed
397
    def print_mode_args(self):
Liangsheng Yin's avatar
Liangsheng Yin committed
398
        return (
399
            f"disable_flashinfer={self.disable_flashinfer}, "
400
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
Lianmin Zheng's avatar
Lianmin Zheng committed
401
            f"disable_radix_cache={self.disable_radix_cache}, "
Liangsheng Yin's avatar
Liangsheng Yin committed
402
403
404
405
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
        )

406
407
408
409
410
411
412
413
    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"

Liangsheng Yin's avatar
Liangsheng Yin committed
414
        assert not (
415
            self.chunked_prefill_size is not None and self.disable_radix_cache
Liangsheng Yin's avatar
Liangsheng Yin committed
416
417
        ), "chunked prefill is not supported with radix cache disabled currently"

Lianmin Zheng's avatar
Lianmin Zheng committed
418
419
420
421

@dataclasses.dataclass
class PortArgs:
    """Port assignments for the server's internal components."""

    # Port used by the tokenizer component.
    tokenizer_port: int
    # Port used by the controller component.
    controller_port: int
    # Port used by the detokenizer component.
    detokenizer_port: int
    # Ports used for NCCL initialization — presumably one per tensor-parallel
    # worker group; confirm against the launcher that fills this in.
    nccl_ports: List[int]