server_args.py 14.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""The arguments of the server."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
import argparse
import dataclasses
20
import random
21
from typing import List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
22
23
24
25


@dataclasses.dataclass
class ServerArgs:
    """The arguments that configure the SGLang server.

    Fields left as ``None`` are filled with derived defaults in
    ``__post_init__`` (e.g. ``tokenizer_path`` falls back to ``model_path``
    and ``mem_fraction_static`` is chosen based on ``tp_size``).
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    load_format: str = "auto"
    dtype: str = "auto"
    # NOTE(review): dataclass default is True, but the CLI flag below is
    # store_true (default False), so CLI-constructed args default to False.
    # Left unchanged — confirm which default is intended before unifying.
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    chat_template: Optional[str] = None

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    # A single int is normalized to a one-element list in __post_init__.
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_prefill_tokens: Optional[int] = None
    max_running_requests: Optional[int] = None
    max_num_reqs: Optional[int] = None
    schedule_heuristic: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: str = ""

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Chunked Prefill
    chunked_prefill_size: Optional[int] = None

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_disk_cache: bool = False
    enable_torch_compile: bool = False
    enable_p2p_check: bool = False
    attention_reduce_in_fp32: bool = False
    efficient_weight_load: bool = False

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    def __post_init__(self):
        """Fill in derived defaults for any fields that were left unset."""
        if self.chunked_prefill_size is None:
            # A huge sentinel effectively disables chunking; check_server_args
            # compares against this same bound.
            self.chunked_prefill_size = 1 << 30
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
        if self.mem_fraction_static is None:
            # Larger tensor-parallel groups reserve a smaller static fraction.
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.80
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.84
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.86
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.88
            else:
                self.mem_fraction_static = 0.89
        # Normalize additional_ports to a list.
        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []
        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all server arguments on *parser*.

        Argument dest names match the dataclass field names so that
        ``from_cli_args`` can rebuild a ``ServerArgs`` from the namespace.
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-num-reqs",
            type=int,
            default=ServerArgs.max_num_reqs,
            help="The maximum number of requests to serve in the memory pool. If the model has a large context length, you may need to decrease this value to avoid out-of-memory errors.",
        )
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
            default=ServerArgs.schedule_heuristic,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling heuristic.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server.",
        )

        # Data parallelism
        parser.add_argument(
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument("--node-rank", type=int, help="The node rank.")

        # Chunked prefill
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The size of the chunked prefill.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile, experimental feature.",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a parsed argparse namespace."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the base HTTP URL of the server."""
        return f"http://{self.host}:{self.port}"

    def print_mode_args(self):
        """Return a string summarizing the optimization/debug flags."""
        return (
            f"disable_flashinfer={self.disable_flashinfer}, "
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
        )

    def check_server_args(self):
        """Validate cross-field constraints; raises AssertionError on violation."""
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
        # chunked_prefill_size < (1 << 30) means chunking is actually enabled
        # (1 << 30 is the "disabled" sentinel set in __post_init__).
        assert not (
            self.chunked_prefill_size < (1 << 30) and self.disable_radix_cache
        ), "chunked prefill is not supported with radix cache disabled currently"

Lianmin Zheng's avatar
Lianmin Zheng committed
413
414
415
416

@dataclasses.dataclass
class PortArgs:
    """Port assignments for the server's internal components."""

    # Port used by the tokenizer process.
    tokenizer_port: int
    # Port used by the controller process.
    controller_port: int
    # Port used by the detokenizer process.
    detokenizer_port: int
    # Ports reserved for NCCL initialization (one per group — presumably;
    # confirm against the caller that allocates them).
    nccl_ports: List[int]